mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-29 11:19:56 +00:00
Compare commits
36 Commits
feat/admin
...
peer-acl-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
959cfa1804 | ||
|
|
2d7b309004 | ||
|
|
5968cff242 | ||
|
|
cf43841b86 | ||
|
|
739e36a313 | ||
|
|
2bb5421631 | ||
|
|
998ade6e6d | ||
|
|
62f5467cd8 | ||
|
|
1b29995ece | ||
|
|
fd96b8c12f | ||
|
|
6dd6c3f398 | ||
|
|
d1422dcf09 | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a | ||
|
|
6c26178ad5 | ||
|
|
97d9559e6d | ||
|
|
ab7639d101 | ||
|
|
a5d4373ddc | ||
|
|
1b6294f2ff | ||
|
|
a14586b142 | ||
|
|
09c0063d71 | ||
|
|
1c5b84b1a1 | ||
|
|
1ed2067b8b | ||
|
|
f7e9df6ffa | ||
|
|
38603c7552 |
@@ -64,7 +64,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: true
|
cache: true
|
||||||
|
|||||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -21,13 +21,13 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
|
|||||||
20
.github/workflows/golang-test-freebsd.yml
vendored
20
.github/workflows/golang-test-freebsd.yml
vendored
@@ -48,14 +48,14 @@ jobs:
|
|||||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -tags privileged -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -tags privileged -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -tags privileged -timeout 1m -failfast ./encryption/...
|
||||||
time go test -timeout 1m -failfast ./formatter/...
|
time go test -tags privileged -timeout 1m -failfast ./formatter/...
|
||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -tags privileged -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -tags privileged -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -tags privileged -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
51
.github/workflows/golang-test-linux.yml
vendored
51
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- 'management/**'
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -41,7 +41,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -135,7 +135,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -158,7 +158,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
@@ -168,7 +168,6 @@ jobs:
|
|||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
flags: unit,client
|
flags: unit,client
|
||||||
|
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -180,7 +179,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -192,7 +191,7 @@ jobs:
|
|||||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache-restore
|
id: cache-restore
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -229,7 +228,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -251,7 +250,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -266,7 +265,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -311,7 +310,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -325,7 +324,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -368,7 +367,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -383,7 +382,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -429,7 +428,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -440,7 +439,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -534,7 +533,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -545,7 +544,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -579,10 +578,11 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
@@ -628,7 +628,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -639,7 +639,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -673,12 +673,13 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/server/http/...
|
-timeout 20m ./management/server/http/...
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
name: "Management / Integration"
|
name: "Management / Integration"
|
||||||
@@ -697,7 +698,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -708,7 +709,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
|
|||||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
@@ -35,7 +35,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
@@ -28,13 +28,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
cmdline-tools-version: 8512546
|
cmdline-tools-version: 8512546
|
||||||
- name: Setup Java
|
- name: Setup Java
|
||||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||||
with:
|
with:
|
||||||
java-version: "11"
|
java-version: "11"
|
||||||
distribution: "adopt"
|
distribution: "adopt"
|
||||||
- name: NDK Cache
|
- name: NDK Cache
|
||||||
id: ndk-cache
|
id: ndk-cache
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: /usr/local/lib/android/sdk/ndk
|
path: /usr/local/lib/android/sdk/ndk
|
||||||
key: ndk-cache-23.1.7779620
|
key: ndk-cache-23.1.7779620
|
||||||
@@ -58,7 +58,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
|
|||||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -166,12 +166,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -374,12 +374,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -469,12 +469,12 @@ jobs:
|
|||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
|||||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
|
|||||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: lint lint-all lint-install setup-hooks
|
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
# Install golangci-lint locally if needed
|
# Install golangci-lint locally if needed
|
||||||
@@ -25,3 +25,15 @@ setup-hooks:
|
|||||||
@git config core.hooksPath .githooks
|
@git config core.hooksPath .githooks
|
||||||
@chmod +x .githooks/pre-push
|
@chmod +x .githooks/pre-push
|
||||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
|
|
||||||
|
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||||
|
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||||
|
test-unit:
|
||||||
|
@go test -tags devcert -timeout 10m ./...
|
||||||
|
|
||||||
|
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||||
|
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||||
|
# Narrow the run with env vars, e.g.:
|
||||||
|
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||||
|
test-privileged:
|
||||||
|
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||||
|
|||||||
@@ -37,6 +37,11 @@
|
|||||||
</strong>
|
</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
> ### 🤖 NetBird Agent Network (Beta)
|
||||||
|
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||||
|
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||||
|
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||||
|
|
||||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|||||||
39
agent-network/README.md
Normal file
39
agent-network/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# NetBird Agent Network
|
||||||
|
|
||||||
|
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||||
|
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||||
|
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||||
|
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||||
|
policy, with no API keys to leak.
|
||||||
|
|
||||||
|
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||||
|
> infrastructure.
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
Agent Network is built on two existing NetBird capabilities:
|
||||||
|
|
||||||
|
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||||
|
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||||
|
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||||
|
key server-side, forwards to the API or gateway, and records usage.
|
||||||
|
|
||||||
|
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||||
|
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||||
|
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||||
|
|
||||||
|
## Where the code lives
|
||||||
|
|
||||||
|
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||||
|
components:
|
||||||
|
|
||||||
|
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||||
|
and runs the per-request middleware pipeline.
|
||||||
|
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||||
|
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||||
|
and usage/access logs.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
Full documentation, architecture, and quickstart:
|
||||||
|
**https://docs.netbird.io/agent-network**
|
||||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||||
ProfileName: activeProf.Name,
|
ProfileName: string(activeProf.ID),
|
||||||
Username: currUser.Username,
|
Username: currUser.Username,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
196
client/cmd/service_privileged_test.go
Normal file
196
client/cmd/service_privileged_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/kardianos/service"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceStartTimeout = 10 * time.Second
|
||||||
|
serviceStopTimeout = 5 * time.Second
|
||||||
|
statusPollInterval = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||||
|
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer timeoutCancel()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(statusPollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||||
|
case <-ticker.C:
|
||||||
|
status, err := s.Status()
|
||||||
|
if err != nil {
|
||||||
|
// Continue polling on transient errors
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if status == expectedStatus {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceLifecycle tests the complete service lifecycle
|
||||||
|
func TestServiceLifecycle(t *testing.T) {
|
||||||
|
// TODO: Add support for Windows and macOS
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.Getenv("CONTAINER") == "true" {
|
||||||
|
t.Skip("Skipping service lifecycle test in container environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
originalServiceName := serviceName
|
||||||
|
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||||
|
defer func() {
|
||||||
|
serviceName = originalServiceName
|
||||||
|
}()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||||
|
|
||||||
|
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service config: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the subtests already cleaned up, there's nothing to do.
|
||||||
|
if _, err := s.Status(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Stop(); err != nil {
|
||||||
|
t.Errorf("cleanup: stop service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Uninstall(); err != nil {
|
||||||
|
t.Errorf("cleanup: uninstall service: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("Install", func(t *testing.T) {
|
||||||
|
installCmd.SetContext(ctx)
|
||||||
|
err := installCmd.RunE(installCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
status, err := s.Status()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, service.StatusUnknown, status)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Start", func(t *testing.T) {
|
||||||
|
startCmd.SetContext(ctx)
|
||||||
|
err := startCmd.RunE(startCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Restart", func(t *testing.T) {
|
||||||
|
restartCmd.SetContext(ctx)
|
||||||
|
err := restartCmd.RunE(restartCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Reconfigure", func(t *testing.T) {
|
||||||
|
originalLogLevel := logLevel
|
||||||
|
logLevel = "debug"
|
||||||
|
defer func() {
|
||||||
|
logLevel = originalLogLevel
|
||||||
|
}()
|
||||||
|
|
||||||
|
reconfigureCmd.SetContext(ctx)
|
||||||
|
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stop", func(t *testing.T) {
|
||||||
|
stopCmd.SetContext(ctx)
|
||||||
|
err := stopCmd.RunE(stopCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, stopped)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Uninstall", func(t *testing.T) {
|
||||||
|
uninstallCmd.SetContext(ctx)
|
||||||
|
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.Status()
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
serviceStartTimeout = 10 * time.Second
|
|
||||||
serviceStopTimeout = 5 * time.Second
|
|
||||||
statusPollInterval = 500 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
|
||||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer timeoutCancel()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(statusPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
|
||||||
case <-ticker.C:
|
|
||||||
status, err := s.Status()
|
|
||||||
if err != nil {
|
|
||||||
// Continue polling on transient errors
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if status == expectedStatus {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceLifecycle tests the complete service lifecycle
|
|
||||||
func TestServiceLifecycle(t *testing.T) {
|
|
||||||
// TODO: Add support for Windows and macOS
|
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
|
||||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.Getenv("CONTAINER") == "true" {
|
|
||||||
t.Skip("Skipping service lifecycle test in container environment")
|
|
||||||
}
|
|
||||||
|
|
||||||
originalServiceName := serviceName
|
|
||||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
|
||||||
defer func() {
|
|
||||||
serviceName = originalServiceName
|
|
||||||
}()
|
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
|
||||||
logLevel = "info"
|
|
||||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
|
||||||
|
|
||||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service config: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the subtests already cleaned up, there's nothing to do.
|
|
||||||
if _, err := s.Status(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.Stop(); err != nil {
|
|
||||||
t.Errorf("cleanup: stop service: %v", err)
|
|
||||||
}
|
|
||||||
if err := s.Uninstall(); err != nil {
|
|
||||||
t.Errorf("cleanup: uninstall service: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
t.Run("Install", func(t *testing.T) {
|
|
||||||
installCmd.SetContext(ctx)
|
|
||||||
err := installCmd.RunE(installCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
status, err := s.Status()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotEqual(t, service.StatusUnknown, status)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Start", func(t *testing.T) {
|
|
||||||
startCmd.SetContext(ctx)
|
|
||||||
err := startCmd.RunE(startCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Restart", func(t *testing.T) {
|
|
||||||
restartCmd.SetContext(ctx)
|
|
||||||
err := restartCmd.RunE(restartCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Reconfigure", func(t *testing.T) {
|
|
||||||
originalLogLevel := logLevel
|
|
||||||
logLevel = "debug"
|
|
||||||
defer func() {
|
|
||||||
logLevel = originalLogLevel
|
|
||||||
}()
|
|
||||||
|
|
||||||
reconfigureCmd.SetContext(ctx)
|
|
||||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Stop", func(t *testing.T) {
|
|
||||||
stopCmd.SetContext(ctx)
|
|
||||||
err := stopCmd.RunE(stopCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, stopped)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Uninstall", func(t *testing.T) {
|
|
||||||
uninstallCmd.SetContext(ctx)
|
|
||||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.Status()
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceEnvVars tests environment variable parsing
|
// TestServiceEnvVars tests environment variable parsing
|
||||||
func TestServiceEnvVars(t *testing.T) {
|
func TestServiceEnvVars(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
11
client/firewall/allower_other.go
Normal file
11
client/firewall/allower_other.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build android || (!linux && !windows)
|
||||||
|
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
|
||||||
|
// interfaceAllower returns no allower: these platforms have no host firewall to
|
||||||
|
// open for the interface.
|
||||||
|
func interfaceAllower(IFaceMapper, uint16) uspfilter.InterfaceAllower {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
10
client/firewall/allower_windows.go
Normal file
10
client/firewall/allower_windows.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
|
||||||
|
// interfaceAllower returns the Windows netsh-based interface allower.
|
||||||
|
func interfaceAllower(iface IFaceMapper, _ uint16) uspfilter.InterfaceAllower {
|
||||||
|
return uspfilter.NewWindowsInterfaceAllower(iface)
|
||||||
|
}
|
||||||
@@ -6,8 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
@@ -21,13 +19,11 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
return uspfilter.Create(uspfilter.Config{
|
||||||
if err != nil {
|
IFace: iface,
|
||||||
return nil, err
|
DisableServerRoutes: disableServerRoutes,
|
||||||
}
|
FlowLogger: flowLogger,
|
||||||
err = fm.AllowNetbird()
|
MTU: mtu,
|
||||||
if err != nil {
|
InterfaceAllower: interfaceAllower(iface, mtu),
|
||||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
})
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -29,47 +30,107 @@ const (
|
|||||||
NFTABLES
|
NFTABLES
|
||||||
)
|
)
|
||||||
|
|
||||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
// SkipNftablesEnv is the environment variable to skip nftables check
|
||||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
|
||||||
|
|
||||||
|
// errNoFirewallManager indicates no kernel firewall backend is present,
|
||||||
|
// as opposed to a backend that exists but failed to create or initialize.
|
||||||
|
var errNoFirewallManager = errors.New("no firewall manager found")
|
||||||
|
|
||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
// Userspace firewall without a native counterpart: routing is handled
|
||||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
// entirely in userspace. The interface is opened in the kernel's foreign
|
||||||
log.Info("forcing userspace firewall")
|
// filter chains via a table-less allower, except in netstack mode where no
|
||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
// kernel interface exists.
|
||||||
|
if netstack.IsEnabled() || (iface.IsUserspaceBind() && forceUserspaceFirewall()) {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Info("netstack mode, using userspace firewall")
|
||||||
|
} else {
|
||||||
|
log.Info("forcing userspace firewall")
|
||||||
|
}
|
||||||
|
cfg := uspfilter.Config{
|
||||||
|
IFace: iface,
|
||||||
|
DisableServerRoutes: disableServerRoutes,
|
||||||
|
FlowLogger: flowLogger,
|
||||||
|
MTU: mtu,
|
||||||
|
InterfaceAllower: interfaceAllower(iface, mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
return uspfilter.Create(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
fm, err := createNativeFirewall(iface, stateManager, mtu)
|
||||||
|
switch {
|
||||||
// Kernel cannot fall back to anything else, need to return error
|
case err == nil && !iface.IsUserspaceBind():
|
||||||
if !iface.IsUserspaceBind() {
|
// Nothing to do, fall through
|
||||||
return fm, err
|
case err == nil && iface.IsUserspaceBind():
|
||||||
}
|
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||||
|
// needs a device filter for DNS interception hooks. Install a minimal
|
||||||
// Fall back to the userspace packet filter if native is unavailable
|
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||||
if err != nil {
|
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
}
|
||||||
}
|
case err != nil && !iface.IsUserspaceBind():
|
||||||
|
// Kernel cannot fall back to anything else, need to return error
|
||||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
return nil, err
|
||||||
// needs a device filter for DNS interception hooks. Install a minimal
|
case err != nil && iface.IsUserspaceBind():
|
||||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
// Fall back to the userspace packet filter if native is unavailable
|
||||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
logNativeFirewallUnavailable(err)
|
||||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
return uspfilter.Create(uspfilter.Config{
|
||||||
|
IFace: iface,
|
||||||
|
DisableServerRoutes: disableServerRoutes,
|
||||||
|
FlowLogger: flowLogger,
|
||||||
|
MTU: mtu,
|
||||||
|
InterfaceAllower: interfaceAllower(iface, mtu),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
// interfaceAllower selects how the userspace firewall opens the interface in
|
||||||
|
// foreign kernel chains: nftables when available (which also opens foreign nft
|
||||||
|
// tables), else iptables (the legacy fallback, filter INPUT only), else nil.
|
||||||
|
// firewalld trust is applied separately by the manager. Netstack has no kernel
|
||||||
|
// interface to open.
|
||||||
|
func interfaceAllower(iface IFaceMapper, mtu uint16) uspfilter.InterfaceAllower {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nftAllower, err := nbnftables.NewInterfaceAllower(iface, mtu)
|
||||||
|
if err == nil {
|
||||||
|
return nftAllower
|
||||||
|
}
|
||||||
|
log.Infof("no nftables interface allower: %v", err)
|
||||||
|
|
||||||
|
iptAllower, err := nbiptables.NewInterfaceAllower(iface)
|
||||||
|
if err == nil {
|
||||||
|
return iptAllower
|
||||||
|
}
|
||||||
|
log.Infof("no iptables interface allower: %v", err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// logNativeFirewallUnavailable logs the fallback to userspace at info level
|
||||||
|
// when no kernel firewall backend exists, and at warn level otherwise.
|
||||||
|
func logNativeFirewallUnavailable(err error) {
|
||||||
|
if errors.Is(err, errNoFirewallManager) {
|
||||||
|
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
|
||||||
|
} else {
|
||||||
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, mtu uint16) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface, mtu)
|
fm, err := createFW(iface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fm.Init(stateManager); err != nil {
|
if err = fm.Init(stateManager); err != nil {
|
||||||
@@ -88,29 +149,10 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
|||||||
log.Info("creating an nftables firewall manager")
|
log.Info("creating an nftables firewall manager")
|
||||||
return nbnftables.Create(iface, mtu)
|
return nbnftables.Create(iface, mtu)
|
||||||
default:
|
default:
|
||||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
return nil, errNoFirewallManager
|
||||||
return nil, errors.New("no firewall manager found")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
|
||||||
var errUsp error
|
|
||||||
if fm != nil {
|
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
|
||||||
} else {
|
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
if errUsp != nil {
|
|
||||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
func check() FWType {
|
func check() FWType {
|
||||||
useIPTABLES := false
|
useIPTABLES := false
|
||||||
@@ -132,35 +174,38 @@ func check() FWType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nf := nftables.Conn{}
|
// Honor the skip env before probing nftables at all.
|
||||||
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
if os.Getenv(SkipNftablesEnv) != "true" {
|
||||||
if !useIPTABLES {
|
nf := nftables.Conn{}
|
||||||
return NFTABLES
|
if chains, err := nf.ListChains(); err == nil {
|
||||||
}
|
if !useIPTABLES {
|
||||||
|
|
||||||
// search for chains where table is filter
|
|
||||||
// if we find one, we assume that nftables manager can be used with iptables
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name == "filter" {
|
|
||||||
return NFTABLES
|
return NFTABLES
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// check tables for the following constraints:
|
// search for chains where table is filter
|
||||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
// if we find one, we assume that nftables manager can be used with iptables
|
||||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
for _, chain := range chains {
|
||||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
if chain.Table.Name == "filter" {
|
||||||
// 4. if we find an error we log and continue with iptables check
|
return NFTABLES
|
||||||
nbTablesList, err := nf.ListTables()
|
}
|
||||||
switch {
|
}
|
||||||
case err == nil && len(iptablesChains) > 0:
|
|
||||||
return IPTABLES
|
// check tables for the following constraints:
|
||||||
case err == nil && len(nbTablesList) != 1:
|
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||||
return NFTABLES
|
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||||
return IPTABLES
|
// 4. if we find an error we log and continue with iptables check
|
||||||
case err != nil:
|
nbTablesList, err := nf.ListTables()
|
||||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
switch {
|
||||||
|
case err == nil && len(iptablesChains) > 0:
|
||||||
|
return IPTABLES
|
||||||
|
case err == nil && len(nbTablesList) != 1:
|
||||||
|
return NFTABLES
|
||||||
|
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||||
|
return IPTABLES
|
||||||
|
case err != nil:
|
||||||
|
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,15 +221,21 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forceUserspaceFirewall reports whether the userspace firewall is forced.
|
||||||
|
// NB_FORCE_USERSPACE_ROUTER is an alias: forcing userspace routing implies the
|
||||||
|
// userspace firewall, since the two are no longer separable.
|
||||||
func forceUserspaceFirewall() bool {
|
func forceUserspaceFirewall() bool {
|
||||||
val := os.Getenv(EnvForceUserspaceFirewall)
|
return envForceBool(EnvForceUserspaceFirewall) || envForceBool(uspfilter.EnvForceUserspaceRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func envForceBool(name string) bool {
|
||||||
|
val := os.Getenv(name)
|
||||||
if val == "" {
|
if val == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
force, err := strconv.ParseBool(val)
|
force, err := strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
log.Warnf("failed to parse %s: %v", name, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return force
|
return force
|
||||||
|
|||||||
@@ -1,560 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"maps"
|
|
||||||
"net"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
ipset "github.com/lrh3321/ipset-go"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
tableName = "filter"
|
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
|
||||||
|
|
||||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
|
||||||
// external DNAT from bypassing ACL rules.
|
|
||||||
mangleFwdKey = "MANGLE-FORWARD"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
|
||||||
|
|
||||||
type entry struct {
|
|
||||||
spec []string
|
|
||||||
position int
|
|
||||||
}
|
|
||||||
|
|
||||||
type aclManager struct {
|
|
||||||
iptablesClient *iptables.IPTables
|
|
||||||
wgIface iFaceMapper
|
|
||||||
entries aclEntries
|
|
||||||
optionalEntries map[string][]entry
|
|
||||||
ipsetStore *ipsetStore
|
|
||||||
v6 bool
|
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
|
||||||
return &aclManager{
|
|
||||||
iptablesClient: iptablesClient,
|
|
||||||
wgIface: wgIface,
|
|
||||||
entries: make(map[string][][]string),
|
|
||||||
optionalEntries: make(map[string][]entry),
|
|
||||||
ipsetStore: newIpsetStore(),
|
|
||||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|
||||||
m.stateManager = stateManager
|
|
||||||
|
|
||||||
m.seedInitialEntries()
|
|
||||||
m.seedInitialOptionalEntries()
|
|
||||||
|
|
||||||
if err := m.cleanChains(); err != nil {
|
|
||||||
return fmt.Errorf("clean chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.createDefaultChains(); err != nil {
|
|
||||||
return fmt.Errorf("create default chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
|
||||||
id []byte,
|
|
||||||
ip net.IP,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
chain := chainNameInputRules
|
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
|
||||||
if m.v6 && ipsetName != "" {
|
|
||||||
ipsetName += "-v6"
|
|
||||||
}
|
|
||||||
proto := protoForFamily(protocol, m.v6)
|
|
||||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
|
|
||||||
mangleSpecs := slices.Clone(specs)
|
|
||||||
mangleSpecs = append(mangleSpecs,
|
|
||||||
"-i", m.wgIface.Name(),
|
|
||||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
|
||||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
|
||||||
)
|
|
||||||
|
|
||||||
specs = append(specs, "-j", actionToStr(action))
|
|
||||||
if ipsetName != "" {
|
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
|
||||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
|
||||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
|
||||||
ipList.addIP(ip.String())
|
|
||||||
return []firewall.Rule{&Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: chain,
|
|
||||||
specs: specs,
|
|
||||||
v6: m.v6,
|
|
||||||
}}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
|
||||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := m.createIPSet(ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("create ipset: %w", err)
|
|
||||||
}
|
|
||||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
|
||||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipList := newIpList(ip.String())
|
|
||||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
// Insert at the beginning of the chain (position 1)
|
|
||||||
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
|
|
||||||
} else {
|
|
||||||
err = m.iptablesClient.Append(tableFilter, chain, specs...)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
|
||||||
log.Errorf("failed to add mangle rule: %v", err)
|
|
||||||
mangleSpecs = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
specs: specs,
|
|
||||||
mangleSpecs: mangleSpecs,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: chain,
|
|
||||||
v6: m.v6,
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateState()
|
|
||||||
|
|
||||||
return []firewall.Rule{rule}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
|
||||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldDestroyIpset := false
|
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
|
||||||
// delete IP from ruleset IPs list and ipset
|
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
|
||||||
ip := net.ParseIP(r.ip)
|
|
||||||
if ip == nil {
|
|
||||||
return fmt.Errorf("parse IP %s", r.ip)
|
|
||||||
}
|
|
||||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
|
||||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
|
||||||
}
|
|
||||||
delete(ipsetList.ips, r.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if after delete, set still contains other IPs,
|
|
||||||
// no need to delete firewall rule and we should exit here
|
|
||||||
if len(ipsetList.ips) != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we delete last IP from the set, that means we need to delete
|
|
||||||
// set itself and associated firewall rule too
|
|
||||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
|
||||||
shouldDestroyIpset = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.mangleSpecs != nil {
|
|
||||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldDestroyIpset {
|
|
||||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
|
||||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
|
||||||
log.Debugf("destroy empty ipset: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("destroy empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) Reset() error {
|
|
||||||
if err := m.cleanChains(); err != nil {
|
|
||||||
return fmt.Errorf("clean chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
|
||||||
func (m *aclManager) cleanChains() error {
|
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
for _, rule := range m.entries["INPUT"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range m.entries["FORWARD"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list chains: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
for _, rule := range m.entries["PREROUTING"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range m.entries[mangleFwdKey] {
|
|
||||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
|
||||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
|
||||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
|
||||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
|
||||||
} else {
|
|
||||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.ipsetStore.deleteIpset(ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) createDefaultChains() error {
|
|
||||||
// chain netbird-acl-input-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
|
||||||
// mangle FORWARD guard rules are handled separately below
|
|
||||||
if chainName == mangleFwdKey {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, entries := range m.optionalEntries {
|
|
||||||
for _, entry := range entries {
|
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
|
||||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
clear(m.optionalEntries)
|
|
||||||
|
|
||||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
|
||||||
for _, rule := range m.entries[mangleFwdKey] {
|
|
||||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
|
||||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
|
||||||
// We want to make sure our traffic is not dropped by existing rules.
|
|
||||||
|
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
|
||||||
func (m *aclManager) seedInitialEntries() {
|
|
||||||
established := getConntrackEstablished()
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
|
||||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
|
||||||
|
|
||||||
// Inbound is handled by our ACLs, the rest is dropped.
|
|
||||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
|
||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
|
||||||
|
|
||||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
|
||||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
|
||||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
|
||||||
// ACL mark check where it cannot be overridden.
|
|
||||||
m.appendToEntries(mangleFwdKey, []string{
|
|
||||||
"-i", m.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
|
||||||
"-j", "ACCEPT",
|
|
||||||
})
|
|
||||||
m.appendToEntries(mangleFwdKey, []string{
|
|
||||||
"-i", m.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "DNAT",
|
|
||||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
|
||||||
"-j", "DROP",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
|
||||||
{
|
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
|
||||||
position: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
|
||||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) updateState() {
|
|
||||||
if m.stateManager == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var currentState *ShutdownState
|
|
||||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
|
||||||
if existingState, ok := existing.(*ShutdownState); ok {
|
|
||||||
currentState = existingState
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if currentState == nil {
|
|
||||||
currentState = &ShutdownState{}
|
|
||||||
}
|
|
||||||
|
|
||||||
currentState.Lock()
|
|
||||||
defer currentState.Unlock()
|
|
||||||
|
|
||||||
// Clone the maps so the persisted state holds a private snapshot. The
|
|
||||||
// live maps keep being mutated by subsequent rule operations while the
|
|
||||||
// state manager marshals the state from its periodic-save goroutine.
|
|
||||||
// Sharing them by reference races the two and aborts the process with a
|
|
||||||
// concurrent map iteration and write.
|
|
||||||
if m.v6 {
|
|
||||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
|
||||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
|
||||||
} else {
|
|
||||||
currentState.ACLEntries = maps.Clone(m.entries)
|
|
||||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
|
||||||
log.Errorf("failed to update state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
|
||||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
|
||||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
|
||||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
|
||||||
if v6 && protocol == firewall.ProtocolICMP {
|
|
||||||
return "ipv6-icmp"
|
|
||||||
}
|
|
||||||
return string(protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
|
||||||
// don't use IP matching if IP is 0.0.0.0
|
|
||||||
matchByIP := !ip.IsUnspecified()
|
|
||||||
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-s", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if protocol != "all" {
|
|
||||||
specs = append(specs, "-p", protocol)
|
|
||||||
}
|
|
||||||
specs = append(specs, applyPort("--sport", sPort)...)
|
|
||||||
specs = append(specs, applyPort("--dport", dPort)...)
|
|
||||||
return specs
|
|
||||||
}
|
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
|
||||||
if action == firewall.ActionAccept {
|
|
||||||
return "ACCEPT"
|
|
||||||
}
|
|
||||||
return "DROP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
|
|
||||||
if ipsetName == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
actionSuffix := ""
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
actionSuffix = "-drop"
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case sPort != nil && dPort != nil:
|
|
||||||
return ipsetName + "-sport-dport" + actionSuffix
|
|
||||||
case sPort != nil:
|
|
||||||
return ipsetName + "-sport" + actionSuffix
|
|
||||||
case dPort != nil:
|
|
||||||
return ipsetName + "-dport" + actionSuffix
|
|
||||||
default:
|
|
||||||
return ipsetName + actionSuffix
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) createIPSet(name string) error {
|
|
||||||
opts := ipset.CreateOptions{
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
if m.v6 {
|
|
||||||
opts.Family = ipset.FamilyIPV6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("created ipset %s with type hash:net", name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
|
||||||
cidr := uint8(32)
|
|
||||||
if ip.To4() == nil {
|
|
||||||
cidr = 128
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: cidr,
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
|
||||||
cidr := uint8(32)
|
|
||||||
if ip.To4() == nil {
|
|
||||||
cidr = 128
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: cidr,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Del(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) flushIPSet(name string) error {
|
|
||||||
return ipset.Flush(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *aclManager) destroyIPSet(name string) error {
|
|
||||||
return ipset.Destroy(name)
|
|
||||||
}
|
|
||||||
346
client/firewall/iptables/chains_linux.go
Normal file
346
client/firewall/iptables/chains_linux.go
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) createContainers() error {
|
||||||
|
for _, chainInfo := range []struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
}{
|
||||||
|
{chainRTFwdIn, tableFilter},
|
||||||
|
{chainRTFwdOut, tableFilter},
|
||||||
|
{chainRTPre, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRdr, tableNat},
|
||||||
|
{chainRTMSSClamp, tableMangle},
|
||||||
|
} {
|
||||||
|
// Fallback: clear chains that survived an unclean shutdown.
|
||||||
|
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||||
|
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
|
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
|
return fmt.Errorf("add static nat rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addJumpRules(); err != nil {
|
||||||
|
return fmt.Errorf("add jump rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
|
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addJumpRules() error {
|
||||||
|
// Jump to nat chain
|
||||||
|
natRule := jumpRuleSpec(chainRTNAT)
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
|
||||||
|
return fmt.Errorf("add nat postrouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNATPost] = natRule
|
||||||
|
|
||||||
|
// Jump to mangle prerouting chain
|
||||||
|
preRule := jumpRuleSpec(chainRTPre)
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
|
||||||
|
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpManglePre] = preRule
|
||||||
|
|
||||||
|
// Jump to nat prerouting chain
|
||||||
|
rdrRule := jumpRuleSpec(chainRTRdr)
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
|
||||||
|
return fmt.Errorf("add nat prerouting jump rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNATPre] = rdrRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) setupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
preRule := []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePre] = preRule
|
||||||
|
}
|
||||||
|
|
||||||
|
postRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePost] = postRule
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// seedInitialEntries adds default rules to the entries map. Rules are
|
||||||
|
// inserted at position 1, so the order here is reversed.
|
||||||
|
//
|
||||||
|
// Existing FORWARD policy decides outbound traffic towards our
|
||||||
|
// interface. If FORWARD policy is "drop", we add an
|
||||||
|
// established/related rule to allow return traffic for inbound rules.
|
||||||
|
func (r *family) seedInitialEntries() {
|
||||||
|
established := getConntrackEstablished()
|
||||||
|
|
||||||
|
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||||
|
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
|
||||||
|
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
|
||||||
|
|
||||||
|
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||||
|
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
|
||||||
|
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
|
||||||
|
|
||||||
|
// Mangle FORWARD guard: when external DNAT redirects traffic from
|
||||||
|
// the wg interface, it traverses FORWARD instead of INPUT,
|
||||||
|
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
|
||||||
|
// inserted above ours. Mangle runs before filter, so these guard
|
||||||
|
// rules enforce the ACL mark check where it cannot be overridden.
|
||||||
|
r.appendToEntries(mangleForwardKey, []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
})
|
||||||
|
r.appendToEntries(mangleForwardKey, []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "DNAT",
|
||||||
|
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
"-j", "DROP",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) seedInitialOptionalEntries() {
|
||||||
|
r.optionalEntries[chainForward] = []entry{
|
||||||
|
{
|
||||||
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||||
|
position: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
|
||||||
|
r.entries[chain] = append(r.entries[chain], spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createDefaultChains() error {
|
||||||
|
if err := r.iptablesClient.NewChain(tableFilter, chainACLInput); err != nil {
|
||||||
|
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for chain, rules := range r.entries {
|
||||||
|
// mangle FORWARD guard rules are handled separately below
|
||||||
|
if chain == mangleForwardKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if err := r.iptablesClient.InsertUnique(tableFilter, string(chain), 1, rule...); err != nil {
|
||||||
|
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for chain, entries := range r.optionalEntries {
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := r.iptablesClient.InsertUnique(tableFilter, string(chain), entry.position, entry.spec...); err != nil {
|
||||||
|
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r.entries[chain] = append(r.entries[chain], entry.spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clear(r.optionalEntries)
|
||||||
|
|
||||||
|
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||||
|
for _, rule := range r.entries[mangleForwardKey] {
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) cleanUpDefaultForwardRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
|
||||||
|
// the others, so the chain below deletes cleanly instead of failing
|
||||||
|
// with "device or resource busy".
|
||||||
|
if err := r.cleanJumpRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chainInfo := range []struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
}{
|
||||||
|
{chainRTFwdIn, tableFilter},
|
||||||
|
{chainRTFwdOut, tableFilter},
|
||||||
|
{chainRTPre, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRdr, tableNat},
|
||||||
|
{chainNATOutput, tableNat},
|
||||||
|
{chainRTMSSClamp, tableMangle},
|
||||||
|
} {
|
||||||
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) cleanJumpRules() error {
|
||||||
|
// locations maps each jump rule to the built-in table and chain it
|
||||||
|
// was inserted into, plus the netbird chain it targets.
|
||||||
|
locations := map[firewall.RuleID]struct{ table, chain, target string }{
|
||||||
|
jumpNATPost: {tableNat, chainPostrouting, chainRTNAT},
|
||||||
|
jumpManglePre: {tableMangle, chainPrerouting, chainRTPre},
|
||||||
|
jumpNATPre: {tableNat, chainPrerouting, chainRTRdr},
|
||||||
|
jumpMSSClamp: {tableMangle, chainForward, chainRTMSSClamp},
|
||||||
|
jumpNATOutput: {tableNat, chainOutput, chainNATOutput},
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for ruleID, loc := range locations {
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
// Untracked (e.g. fresh start after an unclean shutdown with no
|
||||||
|
// restored state): if the target chain survived, remove the stale
|
||||||
|
// jump to it so the chain can be deleted.
|
||||||
|
ok, err := r.iptablesClient.ChainExists(loc.table, loc.target)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", loc.target, loc.table, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rule = jumpRuleSpec(loc.target)
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jumpRuleSpec builds the iptables rule spec that jumps to target. Create
|
||||||
|
// and cleanup sites share it so the installed and deleted specs cannot drift.
|
||||||
|
func jumpRuleSpec(target string) []string {
|
||||||
|
return []string{"-j", target}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) cleanAclChains() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.cleanInputAclChain(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range r.entries[mangleForwardKey] {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) cleanInputAclChain() error {
|
||||||
|
ok, err := r.iptablesClient.ChainExists(tableFilter, chainACLInput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, rule := range r.entries[chainInput] {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainInput, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range r.entries[chainForward] {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainForward, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.ClearAndDeleteChain(tableFilter, chainACLInput); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) cleanupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if preRule, exists := r.rules[markManglePre]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePre)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if postRule, exists := r.rules[markManglePost]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
285
client/firewall/iptables/dnat_linux.go
Normal file
285
client/firewall/iptables/dnat_linux.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
toDestination := rule.TranslatedAddress.String()
|
||||||
|
switch {
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
// no translated port, use original port
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
// need the "/originalport" suffix to avoid dnat port randomization
|
||||||
|
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := strings.ToLower(string(rule.Protocol))
|
||||||
|
|
||||||
|
rules := make(map[firewall.RuleID]ruleInfo, 3)
|
||||||
|
|
||||||
|
// DNAT rule
|
||||||
|
dnatRule := []string{
|
||||||
|
"!", "-i", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", toDestination,
|
||||||
|
}
|
||||||
|
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||||
|
rules[ruleID+dnatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTRdr,
|
||||||
|
rule: dnatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNAT rule
|
||||||
|
snatRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleID+snatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTNAT,
|
||||||
|
rule: snatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward filtering rule, if fwd policy is DROP
|
||||||
|
forwardRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
}
|
||||||
|
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleID+fwdSuffix] = ruleInfo{
|
||||||
|
table: tableFilter,
|
||||||
|
chain: chainRTFwdOut,
|
||||||
|
rule: forwardRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request forwarding once the rule is about to be installed, releasing
|
||||||
|
// it if installation fails so the refcount tracks the real rules.
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||||
|
log.Errorf("rollback failed: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
r.releaseForwarding()
|
||||||
|
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||||
|
}
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||||
|
// On rollback error, add to rules map for next cleanup
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
r.updateState()
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
var found bool
|
||||||
|
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||||
|
found = true
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID+dnatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||||
|
found = true
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
|
||||||
|
found = true
|
||||||
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID+fwdSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
// Release once, only if the rule was present and removed.
|
||||||
|
if merr == nil && found {
|
||||||
|
r.releaseForwarding()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// releaseForwarding drops one IP forwarding reference, logging any error.
|
||||||
|
func (r *family) releaseForwarding() {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("release IP forwarding: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatRule := []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||||
|
"--dport", strconv.Itoa(int(originalPort)),
|
||||||
|
"-d", localAddr.String(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||||
|
}
|
||||||
|
|
||||||
|
info := ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTRdr,
|
||||||
|
rule: dnatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
|
||||||
|
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[ruleID] = info.rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
|
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||||
|
func (r *family) ensureNATOutputChain() error {
|
||||||
|
if _, exists := r.rules[jumpNATOutput]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||||
|
}
|
||||||
|
if !chainExists {
|
||||||
|
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||||
|
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jumpRule := jumpRuleSpec(chainNATOutput)
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
|
||||||
|
if !chainExists {
|
||||||
|
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||||
|
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNATOutput] = jumpRule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ensureNATOutputChain(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatRule := []string{
|
||||||
|
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
||||||
|
"--dport", strconv.Itoa(int(originalPort)),
|
||||||
|
"-d", localAddr.String(),
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
|
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[ruleID] = dnatRule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||||
|
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
252
client/firewall/iptables/family_linux.go
Normal file
252
client/firewall/iptables/family_linux.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// constants needed to manage and create iptable rules
|
||||||
|
const (
|
||||||
|
tableFilter = "filter"
|
||||||
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
|
||||||
|
// chainACLInput is the peer ACL chain that holds installed
|
||||||
|
// peer-filtering rules.
|
||||||
|
chainACLInput = "NETBIRD-ACL-INPUT"
|
||||||
|
|
||||||
|
// mangleForwardKey is the entries map key for mangle FORWARD guard
|
||||||
|
// rules that prevent external DNAT from bypassing ACL rules.
|
||||||
|
mangleForwardKey chainKey = "MANGLE-FORWARD"
|
||||||
|
|
||||||
|
chainInput = "INPUT"
|
||||||
|
chainPostrouting = "POSTROUTING"
|
||||||
|
chainPrerouting = "PREROUTING"
|
||||||
|
chainForward = "FORWARD"
|
||||||
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
|
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
|
||||||
|
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
|
||||||
|
chainRTPre = "NETBIRD-RT-PRE"
|
||||||
|
chainRTRdr = "NETBIRD-RT-RDR"
|
||||||
|
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||||
|
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
|
||||||
|
|
||||||
|
jumpManglePre = "jump-mangle-pre"
|
||||||
|
jumpNATPre = "jump-nat-pre"
|
||||||
|
jumpNATPost = "jump-nat-post"
|
||||||
|
jumpNATOutput = "jump-nat-output"
|
||||||
|
jumpMSSClamp = "jump-mss-clamp"
|
||||||
|
markManglePre = "mark-mangle-pre"
|
||||||
|
markManglePost = "mark-mangle-post"
|
||||||
|
matchSet = "--match-set"
|
||||||
|
|
||||||
|
dnatSuffix firewall.RuleID = "_dnat"
|
||||||
|
snatSuffix firewall.RuleID = "_snat"
|
||||||
|
fwdSuffix firewall.RuleID = "_fwd"
|
||||||
|
|
||||||
|
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv4TCPHeaderSize = 40
|
||||||
|
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv6TCPHeaderSize = 60
|
||||||
|
)
|
||||||
|
|
||||||
|
type ruleInfo struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
rule []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type routeRules map[firewall.RuleID][]string
|
||||||
|
|
||||||
|
// ruleSpec is a single iptables rule expressed as its argument list
|
||||||
|
// (e.g. {"-i", "wg0", "-j", "DROP"}).
|
||||||
|
type ruleSpec []string
|
||||||
|
|
||||||
|
// chainKey identifies the chain a seeded entry belongs to. It holds
|
||||||
|
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
|
||||||
|
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
|
||||||
|
type chainKey string
|
||||||
|
|
||||||
|
// aclEntries maps a chain to the rules seeded into it to jump into or
|
||||||
|
// guard the netbird ACL chains.
|
||||||
|
type aclEntries map[chainKey][]ruleSpec
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
spec ruleSpec
|
||||||
|
position int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipsetCounter is the shared hash:net refcounter used by peer and
|
||||||
|
// route ACLs alike. The ipset library does not support comments, so
|
||||||
|
// the key is just the set name (string).
|
||||||
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
|
// family holds the per-address-family iptables state. One instance
|
||||||
|
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||||
|
// single family; the top-level Manager owns one for v4 and another
|
||||||
|
// for v6.
|
||||||
|
type family struct {
|
||||||
|
iptablesClient *iptables.IPTables
|
||||||
|
wgIface iFaceMapper
|
||||||
|
v6 bool
|
||||||
|
|
||||||
|
// Peer ACL chain bookkeeping.
|
||||||
|
entries aclEntries
|
||||||
|
optionalEntries map[chainKey][]entry
|
||||||
|
|
||||||
|
// filters holds peer + route filter rules keyed by content hash.
|
||||||
|
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||||
|
filters map[nbid.RuleID]*Rule
|
||||||
|
ipsetCounter *ipsetCounter
|
||||||
|
|
||||||
|
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
|
||||||
|
// plumbing that isn't a filter rule).
|
||||||
|
rules routeRules
|
||||||
|
|
||||||
|
// Routing / NAT.
|
||||||
|
legacyManagement bool
|
||||||
|
mtu uint16
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||||
|
r := &family{
|
||||||
|
iptablesClient: iptablesClient,
|
||||||
|
wgIface: wgIface,
|
||||||
|
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||||
|
entries: make(aclEntries),
|
||||||
|
optionalEntries: make(map[chainKey][]entry),
|
||||||
|
filters: make(map[nbid.RuleID]*Rule),
|
||||||
|
rules: make(routeRules),
|
||||||
|
mtu: mtu,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ipsetCounter = refcounter.New(
|
||||||
|
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||||
|
return struct{}{}, r.createIpSet(name, sources)
|
||||||
|
},
|
||||||
|
func(name string, _ struct{}) error {
|
||||||
|
return r.deleteIpSet(name)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// init wires the family to the state manager and installs both the
|
||||||
|
// route ACL containers and the peer ACL chain skeleton.
|
||||||
|
func (r *family) init(stateManager *statemanager.Manager) error {
|
||||||
|
r.stateManager = stateManager
|
||||||
|
|
||||||
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.createContainers(); err != nil {
|
||||||
|
return fmt.Errorf("create containers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.seedInitialEntries()
|
||||||
|
r.seedInitialOptionalEntries()
|
||||||
|
|
||||||
|
if err := r.cleanAclChains(); err != nil {
|
||||||
|
return fmt.Errorf("clean acl chains: %w", err)
|
||||||
|
}
|
||||||
|
if err := r.createDefaultChains(); err != nil {
|
||||||
|
return fmt.Errorf("create default chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset tears down all firewall state owned by this family. ACL
|
||||||
|
// chain cleanup runs before route-chain cleanup because the route
|
||||||
|
// chains are still referenced by FORWARD jumps installed during
|
||||||
|
// seedInitialEntries; deleting them first would trip EBUSY.
|
||||||
|
func (r *family) Reset() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.cleanAclChains(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ipsetCounter.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clear(r.rules)
|
||||||
|
clear(r.filters)
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) updateState() {
|
||||||
|
if r.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentState *ShutdownState
|
||||||
|
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
// Clone the rule maps so the persisted state holds a private snapshot.
|
||||||
|
// The live maps keep being mutated by subsequent rule operations while
|
||||||
|
// the state manager marshals the state from its periodic-save goroutine.
|
||||||
|
// Sharing the maps by reference races the two and aborts the process with
|
||||||
|
// a concurrent map iteration and write. The ipset counter guards itself
|
||||||
|
// during marshaling, so it can be shared directly.
|
||||||
|
if r.v6 {
|
||||||
|
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||||
|
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||||
|
currentState.ACLEntries6 = maps.Clone(r.entries)
|
||||||
|
} else {
|
||||||
|
currentState.RouteRules = maps.Clone(r.rules)
|
||||||
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||||
|
currentState.ACLEntries = maps.Clone(r.entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
346
client/firewall/iptables/filter_linux.go
Normal file
346
client/firewall/iptables/filter_linux.go
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddFilterRule installs a packet-filtering rule. With destination
|
||||||
|
// empty, the rule goes to the peer ACL input chain plus a paired
|
||||||
|
// mangle PREROUTING rule for the redirect mark. With destination set
|
||||||
|
// (prefix or named set), it goes to the route ACL forward chain.
|
||||||
|
// Multi-source rules collapse to one iptables rule via the shared
|
||||||
|
// hash:net ipset.
|
||||||
|
func (r *family) AddFilterRule(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||||
|
if existing, ok := r.filters[ruleID]; ok {
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source match: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
|
||||||
|
if err != nil {
|
||||||
|
r.dropSourceMatch(srcMatch)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.filters[ruleID] = rule
|
||||||
|
r.updateState()
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) hasRule(id nbid.RuleID) bool {
|
||||||
|
_, ok := r.filters[id]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasDNATRule reports whether this family owns the DNAT rule set for
|
||||||
|
// the given user id. DNAT rules live in r.rules under the well-known
|
||||||
|
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
|
||||||
|
// to pick the right family.
|
||||||
|
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||||
|
_, ok := r.rules[id+dnatSuffix]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFilterRule removes a previously installed filter rule. The
|
||||||
|
// rule's stored chain/table identify where to delete from; source set
|
||||||
|
// references are recovered from the spec via findSets and dropped
|
||||||
|
// from the shared ipset counter.
|
||||||
|
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
pr, ok := r.filters[ruleID]
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("filter rule %s not found", ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteIfExists keeps both deletes idempotent so a retry after a
|
||||||
|
// partial failure does not error on the half that was already removed.
|
||||||
|
var merr *multierror.Error
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, pr.chain, pr.specs...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule from %s: %w", pr.chain, err))
|
||||||
|
}
|
||||||
|
if pr.mangleSpecs != nil {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete mangle rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
// Leave the rule tracked so the caller retries the remaining half.
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The rule is gone from iptables, so untrack it regardless of how the
|
||||||
|
// refcount decrement goes, but surface decrement failures so callers
|
||||||
|
// see the ipset desync.
|
||||||
|
delete(r.filters, ruleID)
|
||||||
|
r.updateState()
|
||||||
|
if err := r.decrementSetCounter(pr.specs); err != nil {
|
||||||
|
return fmt.Errorf("drop source set references: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSets scans an iptables rule spec for "-m set --match-set <name>
|
||||||
|
// <dir>" fragments and returns the named sets in occurrence order.
|
||||||
|
// Used at delete time to drop ipsetCounter references.
|
||||||
|
func findSets(rule []string) []string {
|
||||||
|
var sets []string
|
||||||
|
for i, arg := range rule {
|
||||||
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||||
|
sets = append(sets, rule[i+3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
|
}
|
||||||
|
|
||||||
|
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||||
|
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||||
|
// single prefix inline, or an ipset for multiple sources.
|
||||||
|
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||||
|
switch {
|
||||||
|
case len(sources) == 0:
|
||||||
|
return firewall.Network{}
|
||||||
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
|
return firewall.Network{}
|
||||||
|
case len(sources) == 1:
|
||||||
|
return firewall.Network{Prefix: sources[0]}
|
||||||
|
default:
|
||||||
|
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applySourceMatch returns the iptables match fragment for the rule's
|
||||||
|
// source. For a Set it increments the shared ipset's refcount; for a
|
||||||
|
// Prefix it emits a direct -s match; for the wildcard it returns nil.
|
||||||
|
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||||
|
switch {
|
||||||
|
case network.IsSet():
|
||||||
|
if r.ipsetCounter == nil {
|
||||||
|
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
|
||||||
|
}
|
||||||
|
name := r.ipsetName(network.Set.HashedName())
|
||||||
|
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||||
|
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
|
||||||
|
}
|
||||||
|
return []string{"-m", "set", matchSet, name, "src"}, nil
|
||||||
|
case network.IsPrefix():
|
||||||
|
return []string{"-s", network.Prefix.String()}, nil
|
||||||
|
default:
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dropSourceMatch undoes whatever applySourceMatch reserved when
|
||||||
|
// installing a rule fails. Safe to call when the spec is empty or holds
|
||||||
|
// only inline matchers. Decrement errors are logged but not returned:
|
||||||
|
// the install error is what the caller needs to see.
|
||||||
|
func (r *family) dropSourceMatch(srcMatch []string) {
|
||||||
|
if r.ipsetCounter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, name := range findSets(srcMatch) {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||||
|
log.Errorf("rollback ipset decrement %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrementSetCounter drops ipset references owned by a raw rule spec
|
||||||
|
// stored in r.rules (NAT / legacy route entries). It returns an error
|
||||||
|
// aggregate so the caller surfaces decrement failures.
|
||||||
|
func (r *family) decrementSetCounter(rule []string) error {
|
||||||
|
if r.ipsetCounter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, name := range findSets(rule) {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// installFilterRule assembles and writes one iptables filter-chain
|
||||||
|
// rule. With destination empty the rule lands in the peer ACL input
|
||||||
|
// chain and a paired mangle PREROUTING rule is added for the redirect
|
||||||
|
// mark. With destination set the rule lands in the route ACL forward
|
||||||
|
// chain and there is no mangle pairing.
|
||||||
|
func (r *family) installFilterRule(
|
||||||
|
ruleID nbid.RuleID,
|
||||||
|
srcMatch []string,
|
||||||
|
destination firewall.Network,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (*Rule, error) {
|
||||||
|
isRoute := !destination.IsZero()
|
||||||
|
|
||||||
|
proto := protoForFamily(protocol, r.v6)
|
||||||
|
|
||||||
|
specs := slices.Clone(srcMatch)
|
||||||
|
var destExp []string
|
||||||
|
if isRoute {
|
||||||
|
var err error
|
||||||
|
destExp, err = r.applyNetwork("-d", destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||||
|
}
|
||||||
|
specs = append(specs, destExp...)
|
||||||
|
}
|
||||||
|
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
|
||||||
|
|
||||||
|
var mangleSpecs []string
|
||||||
|
if !isRoute {
|
||||||
|
mangleSpecs = slices.Clone(specs)
|
||||||
|
mangleSpecs = append(mangleSpecs,
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
|
|
||||||
|
chain := chainACLInput
|
||||||
|
if isRoute {
|
||||||
|
chain = chainRTFwdIn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer ACL drops are inserted at position 1 so they precede the
|
||||||
|
// chain's catch-all; route ACL drops are inserted at position 2
|
||||||
|
// to sit immediately after the established/related accept rule.
|
||||||
|
var err error
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
pos := 1
|
||||||
|
if isRoute {
|
||||||
|
pos = 2
|
||||||
|
}
|
||||||
|
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
|
||||||
|
} else {
|
||||||
|
err = r.iptablesClient.Append(tableFilter, chain, specs...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
r.dropSourceMatch(destExp)
|
||||||
|
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The mangle redirect-mark rule is best effort: the filter rule itself
|
||||||
|
// is what enforces the ACL, so a mangle failure must not undo it. Drop
|
||||||
|
// the spec so teardown does not try to remove a rule that was not added.
|
||||||
|
if mangleSpecs != nil {
|
||||||
|
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("add mangle rule: %v", err)
|
||||||
|
mangleSpecs = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Rule{
|
||||||
|
id: ruleID,
|
||||||
|
specs: specs,
|
||||||
|
mangleSpecs: mangleSpecs,
|
||||||
|
chain: chain,
|
||||||
|
v6: r.v6,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyNetwork resolves a firewall.Network into the iptables match
|
||||||
|
// fragment for the given direction flag (-s or -d). Set networks
|
||||||
|
// increment the shared ipset refcount; prefixes emit a direct match;
|
||||||
|
// an empty network returns no spec ("match any").
|
||||||
|
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||||
|
direction := "src"
|
||||||
|
if flag == "-d" {
|
||||||
|
direction = "dst"
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsSet() {
|
||||||
|
name := r.ipsetName(network.Set.HashedName())
|
||||||
|
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||||
|
}
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||||
|
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||||
|
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||||
|
if v6 && protocol == firewall.ProtocolICMP {
|
||||||
|
return "ipv6-icmp"
|
||||||
|
}
|
||||||
|
return string(protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterMatchSpecs returns the proto/port match fragment for a
|
||||||
|
// filtering rule. The source match (-s or -m set) is built by the
|
||||||
|
// caller and prepended.
|
||||||
|
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
|
||||||
|
if protocol != "all" {
|
||||||
|
specs = append(specs, "-p", protocol)
|
||||||
|
}
|
||||||
|
specs = append(specs, applyPort("--sport", sPort)...)
|
||||||
|
specs = append(specs, applyPort("--dport", dPort)...)
|
||||||
|
return specs
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionToStr(action firewall.Action) string {
|
||||||
|
if action == firewall.ActionAccept {
|
||||||
|
return "ACCEPT"
|
||||||
|
}
|
||||||
|
return "DROP"
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
|
if port == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
|
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(port.Values) > 1 {
|
||||||
|
portList := make([]string, len(port.Values))
|
||||||
|
for i, p := range port.Values {
|
||||||
|
portList[i] = strconv.Itoa(int(p))
|
||||||
|
}
|
||||||
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
|
}
|
||||||
93
client/firewall/iptables/interface_allower_linux.go
Normal file
93
client/firewall/iptables/interface_allower_linux.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InterfaceAllower opens the NetBird interface on the iptables filter INPUT
|
||||||
|
// chain so the host firewall doesn't drop traffic the userspace firewall
|
||||||
|
// handles. It is the fallback used when nftables is unavailable (an
|
||||||
|
// iptables-legacy host).
|
||||||
|
//
|
||||||
|
// It opens INPUT only: the userspace router never forwards in the kernel.
|
||||||
|
// firewalld trust is handled by the uspfilter manager, not here.
|
||||||
|
type InterfaceAllower struct {
|
||||||
|
ifaceName string
|
||||||
|
ipt4 *iptables.IPTables
|
||||||
|
// ipt6 is nil when the interface has no IPv6 overlay address.
|
||||||
|
ipt6 *iptables.IPTables
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInterfaceAllower builds an iptables allower for the interface. It returns
|
||||||
|
// an error when iptables is unavailable, so the caller can fall back to
|
||||||
|
// firewalld trust.
|
||||||
|
func NewInterfaceAllower(wgIface iFaceMapper) (*InterfaceAllower, error) {
|
||||||
|
ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("iptables not available: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := ipt4.ListChains(tableFilter); err != nil {
|
||||||
|
return nil, fmt.Errorf("iptables filter table not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &InterfaceAllower{ifaceName: wgIface.Name(), ipt4: ipt4}
|
||||||
|
|
||||||
|
// Missing v6 must not break the v4 path: open v4 only and continue.
|
||||||
|
if wgIface.Address().HasIPv6() {
|
||||||
|
ipt6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("ip6tables not available, opening interface on v4 only: %v", err)
|
||||||
|
} else if _, err := ipt6.ListChains(tableFilter); err != nil {
|
||||||
|
log.Warnf("ip6tables filter table not available, opening interface on v4 only: %v", err)
|
||||||
|
} else {
|
||||||
|
a.ipt6 = ipt6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply inserts the interface accept rule on the filter INPUT chain. It removes
|
||||||
|
// any stale rule first so an unclean exit (e.g. SIGKILL, where Close never ran)
|
||||||
|
// is recovered deterministically rather than accumulating duplicates.
|
||||||
|
func (a *InterfaceAllower) Apply() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, ipt := range a.clients() {
|
||||||
|
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("clean stale interface accept rule: %w", err))
|
||||||
|
}
|
||||||
|
if err := ipt.Insert(tableFilter, chainInput, 1, a.inputRule()...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add interface accept rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close removes the interface accept rule.
|
||||||
|
func (a *InterfaceAllower) Close() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, ipt := range a.clients() {
|
||||||
|
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove interface accept rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *InterfaceAllower) inputRule() []string {
|
||||||
|
return []string{"-i", a.ifaceName, "-j", "ACCEPT"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *InterfaceAllower) clients() []*iptables.IPTables {
|
||||||
|
clients := []*iptables.IPTables{a.ipt4}
|
||||||
|
if a.ipt6 != nil {
|
||||||
|
clients = append(clients, a.ipt6)
|
||||||
|
}
|
||||||
|
return clients
|
||||||
|
}
|
||||||
104
client/firewall/iptables/ipset_linux.go
Normal file
104
client/firewall/iptables/ipset_linux.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/lrh3321/ipset-go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
|
if err := r.createIPSet(setName); err != nil {
|
||||||
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range sources {
|
||||||
|
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||||
|
// The refcounter records nothing when this callback errors,
|
||||||
|
// so destroy the set or it leaks in the kernel. A partial
|
||||||
|
// source set would also fail-open for deny rules, so the
|
||||||
|
// rule must fail rather than install with a missing source.
|
||||||
|
if derr := r.destroyIPSet(setName); derr != nil {
|
||||||
|
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) deleteIpSet(setName string) error {
|
||||||
|
if err := r.destroyIPSet(setName); err != nil {
|
||||||
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("deleted unused ipset %s", setName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
name := r.ipsetName(set.HashedName())
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr == nil {
|
||||||
|
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) ipsetName(name string) string {
|
||||||
|
if r.v6 {
|
||||||
|
return name + "-v6"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
if r.v6 {
|
||||||
|
opts.Family = ipset.FamilyIPV6
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
ip := addr.AsSlice()
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: uint8(prefix.Bits()),
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -18,25 +17,21 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resetter interface {
|
// Manager of iptables firewall. Per-family state (peer ACLs, route
|
||||||
Reset() error
|
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||||
}
|
// by family and provides the public firewall.Manager surface.
|
||||||
|
|
||||||
// Manager of iptables firewall
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
aclMgr *aclManager
|
family4 *family
|
||||||
router *router
|
|
||||||
rawSupported bool
|
rawSupported bool
|
||||||
|
|
||||||
// IPv6 counterparts, nil when no v6 overlay
|
// IPv6 counterparts, nil when no v6 overlay
|
||||||
ipv6Client *iptables.IPTables
|
ipv6Client *iptables.IPTables
|
||||||
aclMgr6 *aclManager
|
family6 *family
|
||||||
router6 *router
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create family: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
if wgIface.Address().HasIPv6() {
|
||||||
@@ -81,21 +71,18 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init ip6tables: %w", err)
|
return fmt.Errorf("init ip6tables: %w", err)
|
||||||
}
|
}
|
||||||
m.ipv6Client = ip6Client
|
|
||||||
|
|
||||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
family6, err := newFamily(ip6Client, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create v6 router: %w", err)
|
return fmt.Errorf("create v6 family: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
// Share the same IP forwarding state with the v4 family, since
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
family6.ipFwdState = m.family4.ipFwdState
|
||||||
|
|
||||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
m.ipv6Client = ip6Client
|
||||||
if err != nil {
|
m.family6 = family6
|
||||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -109,7 +96,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.router.mtu,
|
MTU: m.family4.mtu,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
stateManager.RegisterState(state)
|
stateManager.RegisterState(state)
|
||||||
@@ -141,31 +128,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initChains initializes router and ACL chains for both address families,
|
// initChains initializes the per-family firewall state for both
|
||||||
// rolling back on failure.
|
// address families, rolling back on failure.
|
||||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||||
type initStep struct {
|
type initStep struct {
|
||||||
name string
|
name string
|
||||||
init func(*statemanager.Manager) error
|
r *family
|
||||||
mgr resetter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
steps := []initStep{
|
steps := []initStep{{"v4", m.family4}}
|
||||||
{"router", m.router.init, m.router},
|
|
||||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
steps = append(steps,
|
steps = append(steps, initStep{"v6", m.family6})
|
||||||
initStep{"v6 router", m.router6.init, m.router6},
|
|
||||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var initialized []initStep
|
var initialized []initStep
|
||||||
for _, s := range steps {
|
for _, s := range steps {
|
||||||
if err := s.init(stateManager); err != nil {
|
if err := s.r.init(stateManager); err != nil {
|
||||||
for i := len(initialized) - 1; i >= 0; i-- {
|
for i := len(initialized) - 1; i >= 0; i-- {
|
||||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
if rerr := initialized[i].r.Reset(); rerr != nil {
|
||||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -176,84 +156,50 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
|
||||||
//
|
// docs for destination semantics. Sources are a single address family;
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// the rule is dispatched to the matching v4 / v6 backend.
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddFilterRule(
|
||||||
id []byte,
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination firewall.Network,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort, dPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
if len(sources) == 0 {
|
||||||
|
return nil, firewall.ErrNoSources
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if isIPv6RouteRule(sources, destination) {
|
fam := m.family4
|
||||||
|
if isIPv6Rule(sources, destination) {
|
||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
fam = m.family6
|
||||||
}
|
}
|
||||||
|
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
// DeleteFilterRule removes a rule previously added via AddFilterRule.
|
||||||
if destination.IsPrefix() {
|
// The rule is looked up by id in each family's filter cache.
|
||||||
return destination.Prefix.Addr().Is6()
|
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
id := rule.ID()
|
||||||
return m.aclMgr6.DeletePeerRule(rule)
|
if m.family4.hasRule(id) {
|
||||||
|
return m.family4.DeleteFilterRule(rule)
|
||||||
}
|
}
|
||||||
return m.aclMgr.DeletePeerRule(rule)
|
if m.hasIPv6() && m.family6.hasRule(id) {
|
||||||
}
|
return m.family6.DeleteFilterRule(rule)
|
||||||
|
|
||||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
return ok && r.v6
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule.
|
|
||||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
|
||||||
return m.router6.DeleteRouteRule(rule)
|
|
||||||
}
|
}
|
||||||
return m.router.DeleteRouteRule(rule)
|
log.Debugf("filter rule %s not found in any family", id)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
@@ -272,10 +218,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddNatRule(pair)
|
return m.family6.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router.AddNatRule(pair); err != nil {
|
if err := m.family4.AddNatRule(pair); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,7 +230,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,18 +246,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.router6.RemoveNatRule(pair)
|
return m.family6.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -320,11 +266,14 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -341,19 +290,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.aclMgr6.Reset(); err != nil {
|
if err := m.family6.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||||
}
|
|
||||||
if err := m.router6.Reset(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.aclMgr.Reset(); err != nil {
|
if err := m.family4.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||||
}
|
|
||||||
if err := m.router.Reset(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||||
@@ -372,27 +315,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic.
|
|
||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
@@ -402,14 +324,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -424,9 +346,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddDNATRule(rule)
|
return m.family6.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.router.AddDNATRule(rule)
|
return m.family4.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
@@ -434,10 +356,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
|
||||||
return m.router6.DeleteDNATRule(rule)
|
return m.family6.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.router.DeleteDNATRule(rule)
|
return m.family4.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSet updates the set with the given prefixes
|
// UpdateSet updates the set with the given prefixes
|
||||||
@@ -454,12 +376,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -476,9 +398,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
@@ -490,9 +412,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
@@ -504,9 +426,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
@@ -518,14 +440,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRaw = "NETBIRD-RAW"
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
chainOUTPUT = "OUTPUT"
|
chainOutput = "OUTPUT"
|
||||||
tableRaw = "raw"
|
tableRaw = "raw"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -600,15 +522,15 @@ func (m *Manager) initNoTrackChain() error {
|
|||||||
|
|
||||||
jumpRule := []string{"-j", chainNameRaw}
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
|
||||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
log.Debugf("delete orphan chain: %v", delErr)
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("add output jump rule: %w", err)
|
return fmt.Errorf("add output jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
|
||||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
|
||||||
log.Debugf("delete output jump rule: %v", delErr)
|
log.Debugf("delete output jump rule: %v", delErr)
|
||||||
}
|
}
|
||||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
@@ -635,11 +557,11 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
|
|
||||||
jumpRule := []string{"-j", chainNameRaw}
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
|
||||||
return fmt.Errorf("remove output jump rule: %w", err)
|
return fmt.Errorf("remove output jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
|
||||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -654,3 +576,13 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
|
||||||
|
// the destination prefix when set, otherwise from the (single-family)
|
||||||
|
// sources.
|
||||||
|
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
|
}
|
||||||
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -65,46 +67,39 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
rr := rule2.(*Rule)
|
||||||
rr := r.(*Rule)
|
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
require.NoError(t, err, "failed check chain exists")
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -126,15 +121,13 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
|||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{22}}
|
port := &fw.Port{Values: []uint16{22}}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
|
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
require.NotEmpty(t, rule, "deny rule should not be empty")
|
require.NotNil(t, rule, "deny rule should not be nil")
|
||||||
|
|
||||||
// Verify the rule was added by checking iptables
|
// Verify the rule was added by checking iptables
|
||||||
for _, r := range rule {
|
rr := rule.(*Rule)
|
||||||
rr := r.(*Rule)
|
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("deny rule precedence test", func(t *testing.T) {
|
t.Run("deny rule precedence test", func(t *testing.T) {
|
||||||
@@ -142,36 +135,40 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
|||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
|
||||||
// Add accept rule first
|
// Add accept rule first
|
||||||
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
|
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add accept rule")
|
require.NoError(t, err, "failed to add accept rule")
|
||||||
|
|
||||||
// Add deny rule second for same IP/port - this should take precedence
|
// Add deny rule second for same IP/port - this should take precedence
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
|
|
||||||
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
||||||
rules, err := ipv4Client.List("filter", chainNameInputRules)
|
rules, err := ipv4Client.List("filter", chainACLInput)
|
||||||
require.NoError(t, err, "failed to list iptables rules")
|
require.NoError(t, err, "failed to list iptables rules")
|
||||||
|
|
||||||
// Debug: print all rules
|
// Debug: print all rules
|
||||||
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
|
t.Logf("All iptables rules in chain %s:", chainACLInput)
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
t.Logf(" [%d] %s", i, rule)
|
t.Logf(" [%d] %s", i, rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
|
||||||
|
// match. Match on that shape instead of the legacy
|
||||||
|
// per-(action,port) ipset names ("deny-http"/"accept-http")
|
||||||
|
// that this test predates.
|
||||||
|
srcMatch := fmt.Sprintf("-s %s/32", ip)
|
||||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
if strings.Contains(rule, "DROP") {
|
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
|
||||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
continue
|
||||||
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
|
|
||||||
denyRuleIndex = i
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if strings.Contains(rule, "ACCEPT") {
|
if strings.Contains(rule, "-j DROP") {
|
||||||
|
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||||
|
denyRuleIndex = i
|
||||||
|
}
|
||||||
|
if strings.Contains(rule, "-j ACCEPT") {
|
||||||
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
||||||
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
|
acceptRuleIndex = i
|
||||||
acceptRuleIndex = i
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,7 +193,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
|
||||||
manager, err := Create(mock, iface.DefaultMTU)
|
manager, err := Create(mock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -210,27 +206,39 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||||
for _, r := range rule2 {
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NotNil(t, rule2)
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Contains(t, rule2.(*Rule).specs, "-s",
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
"single-source rule should use direct -s match, not an ipset")
|
||||||
}
|
require.Empty(t, findSets(rule2.(*Rule).specs),
|
||||||
|
"single-source rule should not allocate a shared ipset")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete single-source rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||||
err := manager.DeletePeerRule(r)
|
})
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
t.Run("multi-source uses shared ipset", func(t *testing.T) {
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
|
||||||
|
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
|
||||||
|
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
|
||||||
}
|
}
|
||||||
|
port := &fw.Port{Values: []uint16{8080}}
|
||||||
|
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "failed to add multi-source rule")
|
||||||
|
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
|
||||||
|
sets := findSets(multi.(*Rule).specs)
|
||||||
|
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
|
||||||
|
|
||||||
|
require.NoError(t, manager.DeleteFilterRule(multi))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
@@ -281,7 +289,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 11. MSS clamping rule for outbound traffic
|
// 11. MSS clamping rule for outbound traffic
|
||||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
|
||||||
require.True(t, exists, "postrouting jump rule should exist")
|
require.True(t, exists, "postrouting jump rule should exist")
|
||||||
|
|
||||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
|
||||||
require.True(t, exists, "prerouting jump rule should exist")
|
require.True(t, exists, "prerouting jump rule should exist")
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
err = manager.AddNatRule(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "marking rule should be inserted")
|
require.NoError(t, err, "marking rule should be inserted")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||||
markingRule := []string{
|
markingRule := []string{
|
||||||
"-i", ifaceMock.Name(),
|
"-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "marking rule should be created")
|
require.True(t, exists, "marking rule should be created")
|
||||||
foundRule, found := manager.rules[natRuleKey]
|
foundRule, found := manager.rules[natRuleKey]
|
||||||
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Check inverse rule
|
// Check inverse rule
|
||||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||||
inverseMarkingRule := []string{
|
inverseMarkingRule := []string{
|
||||||
"!", "-i", ifaceMock.Name(),
|
"!", "-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "inverse marking rule should be created")
|
require.True(t, exists, "inverse marking rule should be created")
|
||||||
foundRule, found := manager.rules[inverseRuleKey]
|
foundRule, found := manager.rules[inverseRuleKey]
|
||||||
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||||
markingRule := []string{
|
markingRule := []string{
|
||||||
"-i", ifaceMock.Name(),
|
"-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||||
require.False(t, exists, "marking rule should not exist")
|
require.False(t, exists, "marking rule should not exist")
|
||||||
|
|
||||||
_, found := manager.rules[natRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Check inverse rule removal
|
// Check inverse rule removal
|
||||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||||
inverseMarkingRule := []string{
|
inverseMarkingRule := []string{
|
||||||
"!", "-i", ifaceMock.Name(),
|
"!", "-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||||
require.False(t, exists, "inverse marking rule should not exist")
|
require.False(t, exists, "inverse marking rule should not exist")
|
||||||
|
|
||||||
_, found = manager.rules[inverseRuleKey]
|
_, found = manager.rules[inverseRuleKey]
|
||||||
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "Failed to create iptables client")
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router manager")
|
require.NoError(t, err, "Failed to create family manager")
|
||||||
require.NoError(t, r.init(nil))
|
require.NoError(t, r.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := r.Reset()
|
err := r.Reset()
|
||||||
require.NoError(t, err, "Failed to reset router")
|
require.NoError(t, err, "Failed to reset family")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddFilterRule failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
stored, ok := r.filters[ruleKey.ID()]
|
||||||
rule, ok := r.rules[ruleKey.ID()]
|
require.True(t, ok, "rule not stored in filters")
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
t.Logf("Internal rule: %v", stored.specs)
|
||||||
|
|
||||||
// Log the internal rule
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
|
||||||
t.Logf("Internal rule: %v", rule)
|
|
||||||
|
|
||||||
// Check if the rule exists in iptables
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
|
||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
var source firewall.Network
|
|
||||||
if len(tt.sources) > 1 {
|
|
||||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
|
||||||
} else if len(tt.sources) > 0 {
|
|
||||||
source.Prefix = tt.sources[0]
|
|
||||||
}
|
|
||||||
// Verify rule content
|
|
||||||
params := routeFilteringRuleParams{
|
|
||||||
Source: source,
|
|
||||||
Destination: firewall.Network{Prefix: tt.destination},
|
|
||||||
Proto: tt.proto,
|
|
||||||
SPort: tt.sPort,
|
|
||||||
DPort: tt.dPort,
|
|
||||||
Action: tt.action,
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
|
||||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
|
||||||
|
|
||||||
if tt.expectSet {
|
if tt.expectSet {
|
||||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
|
||||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
|
||||||
|
|
||||||
// Check if the set was created
|
|
||||||
_, exists := r.ipsetCounter.Get(setName)
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
assert.True(t, exists, "IPSet not created")
|
assert.True(t, exists, "IPSet not created")
|
||||||
|
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||||
|
|
||||||
// Clean up
|
|
||||||
err = r.DeleteRouteRule(ruleKey)
|
|
||||||
require.NoError(t, err, "Failed to delete rule")
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindSetNameInRule(t *testing.T) {
|
func TestFindSetNameInRule(t *testing.T) {
|
||||||
r := &router{}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
rule []string
|
rule []string
|
||||||
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
result := r.findSets(tc.rule)
|
result := findSets(tc.rule)
|
||||||
|
|
||||||
if len(result) != len(tc.expected) {
|
if len(result) != len(tc.expected) {
|
||||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||||
|
|||||||
263
client/firewall/iptables/routing_linux.go
Normal file
263
client/firewall/iptables/routing_linux.go
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if r.legacyManagement {
|
||||||
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
|
if err := r.addNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("add nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
|
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if pair.Masquerade {
|
||||||
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
|
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||||
|
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
|
||||||
|
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||||
|
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLegacyManagement returns the current legacy management mode
|
||||||
|
func (r *family) GetLegacyManagement() bool {
|
||||||
|
return r.legacyManagement
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||||
|
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||||
|
r.legacyManagement = isLegacy
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||||
|
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for k, rule := range r.rules {
|
||||||
|
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addPostroutingRules() error {
|
||||||
|
// First rule for outbound masquerade
|
||||||
|
rule1 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
"!", "-o", "lo",
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||||
|
return fmt.Errorf("add outbound masquerade rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-outbound"] = rule1
|
||||||
|
|
||||||
|
// Second rule for return traffic masquerade
|
||||||
|
rule2 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||||
|
return fmt.Errorf("add return masquerade rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-return"] = rule2
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
func (r *family) addMSSClampingRules() error {
|
||||||
|
overhead := uint16(ipv4TCPHeaderSize)
|
||||||
|
if r.v6 {
|
||||||
|
overhead = ipv6TCPHeaderSize
|
||||||
|
}
|
||||||
|
mss := r.mtu - overhead
|
||||||
|
|
||||||
|
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||||
|
jumpRule := jumpRuleSpec(chainRTMSSClamp)
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpMSSClamp] = jumpRule
|
||||||
|
|
||||||
|
ruleOut := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", "tcp",
|
||||||
|
"--tcp-flags", "SYN,RST", "SYN",
|
||||||
|
"-j", "TCPMSS",
|
||||||
|
"--set-mss", fmt.Sprintf("%d", mss),
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
|
||||||
|
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules["mss-clamp-out"] = ruleOut
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) insertEstablishedRule(chain string) error {
|
||||||
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
|
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := firewall.RuleID("established-" + chain)
|
||||||
|
r.rules[ruleID] = establishedRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.NatFormat)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||||
|
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||||
|
if pair.Inverse {
|
||||||
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := []string{"-i", r.wgIface.Name()}
|
||||||
|
if pair.Inverse {
|
||||||
|
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule,
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
)
|
||||||
|
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -s: %w", err)
|
||||||
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -d: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
rule = append(rule,
|
||||||
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
|
||||||
|
r.dropSourceMatch(rule)
|
||||||
|
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.NatFormat)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||||
|
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("marking rule %s not found", ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,18 +1,20 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
// Rule to handle management of rules
|
import "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
type Rule struct {
|
|
||||||
ruleID string
|
|
||||||
ipsetName string
|
|
||||||
|
|
||||||
|
// Rule to handle management of rules. Source set membership (when the
|
||||||
|
// rule was built against a shared hash:net ipset) is encoded in specs;
|
||||||
|
// DeleteFilterRule recovers it via findSets so the refcounter can drop
|
||||||
|
// the right reference.
|
||||||
|
type Rule struct {
|
||||||
|
id manager.RuleID
|
||||||
specs []string
|
specs []string
|
||||||
mangleSpecs []string
|
mangleSpecs []string
|
||||||
ip string
|
|
||||||
chain string
|
chain string
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *Rule) ID() string {
|
func (r *Rule) ID() manager.RuleID {
|
||||||
return r.ruleID
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"maps"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ipList struct {
|
|
||||||
ips map[string]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpList(ip string) *ipList {
|
|
||||||
ips := make(map[string]struct{})
|
|
||||||
ips[ip] = struct{}{}
|
|
||||||
|
|
||||||
return &ipList{
|
|
||||||
ips: ips,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipList) addIP(ip string) {
|
|
||||||
s.ips[ip] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clone returns a deep copy of the ipList with its own ips map.
|
|
||||||
func (s *ipList) clone() *ipList {
|
|
||||||
if s == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &ipList{ips: maps.Clone(s.ips)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler
|
|
||||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
IPs map[string]struct{} `json:"ips"`
|
|
||||||
}{
|
|
||||||
IPs: s.ips,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON implements json.Unmarshaler
|
|
||||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
|
||||||
temp := struct {
|
|
||||||
IPs map[string]struct{} `json:"ips"`
|
|
||||||
}{}
|
|
||||||
if err := json.Unmarshal(data, &temp); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.ips = temp.IPs
|
|
||||||
|
|
||||||
if temp.IPs == nil {
|
|
||||||
temp.IPs = make(map[string]struct{})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipsetStore struct {
|
|
||||||
ipsets map[string]*ipList
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
|
||||||
return &ipsetStore{
|
|
||||||
ipsets: make(map[string]*ipList),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
|
||||||
// independent ipList entries.
|
|
||||||
func (s *ipsetStore) clone() *ipsetStore {
|
|
||||||
if s == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
|
||||||
for name, list := range s.ipsets {
|
|
||||||
cloned.ipsets[name] = list.clone()
|
|
||||||
}
|
|
||||||
return cloned
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
|
||||||
r, ok := s.ipsets[ipsetName]
|
|
||||||
return r, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
|
||||||
s.ipsets[ipsetName] = list
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
|
||||||
delete(s.ipsets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ipsetNames() []string {
|
|
||||||
names := make([]string, 0, len(s.ipsets))
|
|
||||||
for name := range s.ipsets {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler
|
|
||||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(struct {
|
|
||||||
IPSets map[string]*ipList `json:"ipsets"`
|
|
||||||
}{
|
|
||||||
IPSets: s.ipsets,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON implements json.Unmarshaler
|
|
||||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
|
||||||
temp := struct {
|
|
||||||
IPSets map[string]*ipList `json:"ipsets"`
|
|
||||||
}{}
|
|
||||||
if err := json.Unmarshal(data, &temp); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.ipsets = temp.IPSets
|
|
||||||
|
|
||||||
if temp.IPSets == nil {
|
|
||||||
temp.IPSets = make(map[string]*ipList)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -29,17 +29,13 @@ type ShutdownState struct {
|
|||||||
|
|
||||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
|
||||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
|
||||||
|
|
||||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
|
||||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
|
||||||
|
|
||||||
// IPv6 counterparts
|
|
||||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||||
|
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
|
||||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
|
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -57,17 +53,14 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.RouteRules != nil {
|
if s.RouteRules != nil {
|
||||||
ipt.router.rules = s.RouteRules
|
ipt.family4.rules = s.RouteRules
|
||||||
}
|
}
|
||||||
if s.RouteIPsetCounter != nil {
|
if s.RouteIPsetCounter != nil {
|
||||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ACLEntries != nil {
|
if s.ACLEntries != nil {
|
||||||
ipt.aclMgr.entries = s.ACLEntries
|
ipt.family4.entries = s.ACLEntries
|
||||||
}
|
|
||||||
if s.ACLIPsetStore != nil {
|
|
||||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up v6 state even if the current run has no IPv6.
|
// Clean up v6 state even if the current run has no IPv6.
|
||||||
@@ -79,16 +72,13 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
}
|
}
|
||||||
if ipt.hasIPv6() {
|
if ipt.hasIPv6() {
|
||||||
if s.RouteRules6 != nil {
|
if s.RouteRules6 != nil {
|
||||||
ipt.router6.rules = s.RouteRules6
|
ipt.family6.rules = s.RouteRules6
|
||||||
}
|
}
|
||||||
if s.RouteIPsetCounter6 != nil {
|
if s.RouteIPsetCounter6 != nil {
|
||||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||||
}
|
}
|
||||||
if s.ACLEntries6 != nil {
|
if s.ACLEntries6 != nil {
|
||||||
ipt.aclMgr6.entries = s.ACLEntries6
|
ipt.family6.entries = s.ACLEntries6
|
||||||
}
|
|
||||||
if s.ACLIPsetStore6 != nil {
|
|
||||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pfx(ip net.IP) []netip.Prefix {
|
||||||
|
if ip == nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
if ip.IsUnspecified() {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
a, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||||
|
}
|
||||||
|
a = a.Unmap()
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package manager
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
@@ -16,6 +15,12 @@ import (
|
|||||||
// method but the IPv6 firewall components were not initialized.
|
// method but the IPv6 firewall components were not initialized.
|
||||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||||
|
|
||||||
|
// ErrNoSources is returned when AddFilterRule is called with an empty
|
||||||
|
// source list. "Match any source" must be expressed explicitly with a
|
||||||
|
// /0 prefix; an empty list is a caller error and is rejected rather
|
||||||
|
// than silently widening the rule to every source.
|
||||||
|
var ErrNoSources = errors.New("rule has no sources")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ForwardingFormatPrefix = "netbird-fwd-"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
@@ -23,13 +28,18 @@ const (
|
|||||||
NatFormat = "netbird-nat-%s-%t"
|
NatFormat = "netbird-nat-%s-%t"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RuleID identifies a firewall rule. It is a typed string so the
|
||||||
|
// compiler catches accidental mixing with arbitrary string keys. It is
|
||||||
|
// only an identifier and does not implement Rule.
|
||||||
|
type RuleID string
|
||||||
|
|
||||||
// Rule abstraction should be implemented by each firewall manager
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
//
|
//
|
||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
// of the properties to hold data of the created rule
|
// of the properties to hold data of the created rule
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
ID() string
|
ID() RuleID
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// RuleDirection is the traffic direction which a rule is applied
|
||||||
@@ -91,6 +101,13 @@ func (d Network) IsPrefix() bool {
|
|||||||
return d.Prefix.IsValid()
|
return d.Prefix.IsValid()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsZero returns true if the network designates no destination, i.e. it
|
||||||
|
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
|
||||||
|
// one carries a prefix or set destination.
|
||||||
|
func (d Network) IsZero() bool {
|
||||||
|
return !d.IsPrefix() && !d.IsSet()
|
||||||
|
}
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
// Manager is the high level abstraction of a firewall manager
|
||||||
//
|
//
|
||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
@@ -98,46 +115,42 @@ func (d Network) IsPrefix() bool {
|
|||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init(stateManager *statemanager.Manager) error
|
Init(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AddFilterRule adds a packet-filtering rule to the firewall.
|
||||||
AllowNetbird() error
|
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If destination is the zero Network, the rule applies to traffic
|
||||||
// rule ID as comment for the rule
|
// inbound to this node, i.e. peer ACL semantics, installed in
|
||||||
|
// the kernel's input chain. If destination is set (prefix or
|
||||||
|
// set), the rule applies to forwarded traffic with that
|
||||||
|
// destination, route ACL semantics, installed in the forward
|
||||||
|
// chain.
|
||||||
//
|
//
|
||||||
// Note: Callers should call Flush() after adding rules to ensure
|
// sources must be a single address family; the caller splits mixed
|
||||||
// they are applied to the kernel and rule handles are refreshed.
|
// families and calls once per family. "Match any source" must be
|
||||||
AddPeerFiltering(
|
// expressed with an explicit /0 prefix; an empty sources list is
|
||||||
|
// rejected with ErrNoSources so a zeroed list can never widen a
|
||||||
|
// rule to every source.
|
||||||
|
//
|
||||||
|
// Note: callers should call Flush() after adding rules.
|
||||||
|
AddFilterRule(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip net.IP,
|
sources []netip.Prefix,
|
||||||
|
destination Network,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
) (Rule, error)
|
||||||
) ([]Rule, error)
|
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeleteFilterRule removes a filtering rule previously added via
|
||||||
DeletePeerRule(rule Rule) error
|
// AddFilterRule. The rule's own type identifies whether it lives
|
||||||
|
// in the peer (input) or route (forward) path.
|
||||||
|
DeleteFilterRule(rule Rule) error
|
||||||
|
|
||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
IsStateful() bool
|
IsStateful() bool
|
||||||
|
|
||||||
AddRouteFiltering(
|
|
||||||
id []byte,
|
|
||||||
sources []netip.Prefix,
|
|
||||||
destination Network,
|
|
||||||
proto Protocol,
|
|
||||||
sPort, dPort *Port,
|
|
||||||
action Action,
|
|
||||||
) (Rule, error)
|
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
|
||||||
DeleteRouteRule(rule Rule) error
|
|
||||||
|
|
||||||
// AddNatRule inserts a routing NAT rule
|
// AddNatRule inserts a routing NAT rule
|
||||||
AddNatRule(pair RouterPair) error
|
AddNatRule(pair RouterPair) error
|
||||||
|
|
||||||
@@ -185,8 +198,9 @@ type Manager interface {
|
|||||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
// GenKey builds the rule id for this pair from the given format.
|
||||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
func (p RouterPair) GenKey(format string) RuleID {
|
||||||
|
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
|
||||||
}
|
}
|
||||||
|
|
||||||
// LegacyManager defines the interface for legacy management operations
|
// LegacyManager defines the interface for legacy management operations
|
||||||
@@ -242,6 +256,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
|||||||
return merged
|
return merged
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
|
||||||
|
// plain v4 form, shifting the prefix length out of the 96-bit mapped
|
||||||
|
// range. Other prefixes are returned unchanged. Keeping prefixes
|
||||||
|
// unmapped ensures v4 rules match consistently and the match builders
|
||||||
|
// read the correct address length.
|
||||||
|
func UnmapPrefix(p netip.Prefix) netip.Prefix {
|
||||||
|
addr := p.Addr()
|
||||||
|
if !addr.Is4In6() {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
bits := max(p.Bits()-96, 0)
|
||||||
|
return netip.PrefixFrom(addr.Unmap(), bits)
|
||||||
|
}
|
||||||
|
|
||||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||||
func SortPrefixes(prefixes []netip.Prefix) {
|
func SortPrefixes(prefixes []netip.Prefix) {
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ type ForwardRule struct {
|
|||||||
TranslatedPort Port
|
TranslatedPort Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ForwardRule) ID() string {
|
func (r ForwardRule) ID() RuleID {
|
||||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||||
r.Protocol,
|
r.Protocol,
|
||||||
r.DestinationPort.String(),
|
r.DestinationPort.String(),
|
||||||
r.TranslatedAddress.String(),
|
r.TranslatedAddress.String(),
|
||||||
r.TranslatedPort.String())
|
r.TranslatedPort.String())
|
||||||
return id
|
return RuleID(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ForwardRule) String() string {
|
func (r ForwardRule) String() string {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (h Set) Comment() string {
|
|||||||
|
|
||||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||||
// sort for consistent naming
|
prefixes = slices.Clone(prefixes)
|
||||||
SortPrefixes(prefixes)
|
SortPrefixes(prefixes)
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
|
|||||||
@@ -1,713 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
|
||||||
chainNameInputRules = "netbird-acl-input-rules"
|
|
||||||
|
|
||||||
// filter chains contains the rules that jump to the rules chains
|
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
|
||||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
|
||||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
|
||||||
)
|
|
||||||
|
|
||||||
const flushError = "flush: %w"
|
|
||||||
|
|
||||||
type AclManager struct {
|
|
||||||
rConn *nftables.Conn
|
|
||||||
sConn *nftables.Conn
|
|
||||||
wgIface iFaceMapper
|
|
||||||
routingFwChainName string
|
|
||||||
af addrFamily
|
|
||||||
|
|
||||||
workTable *nftables.Table
|
|
||||||
chainInputRules *nftables.Chain
|
|
||||||
chainPrerouting *nftables.Chain
|
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
|
||||||
rules map[string]*Rule
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
|
||||||
// and is permanent. Using same connection for both type of operations
|
|
||||||
// overloads netlink with high amount of rules ( > 10000)
|
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create nf conn: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AclManager{
|
|
||||||
rConn: &nftables.Conn{},
|
|
||||||
sConn: sConn,
|
|
||||||
wgIface: wgIface,
|
|
||||||
workTable: table,
|
|
||||||
routingFwChainName: routingFwChainName,
|
|
||||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
|
||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
|
||||||
rules: make(map[string]*Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) init(workTable *nftables.Table) error {
|
|
||||||
m.workTable = workTable
|
|
||||||
return m.createDefaultChains()
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
|
||||||
//
|
|
||||||
// If comment argument is empty firewall manager should set
|
|
||||||
// rule ID as comment for the rule
|
|
||||||
func (m *AclManager) AddPeerFiltering(
|
|
||||||
id []byte,
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
var ipset *nftables.Set
|
|
||||||
if ipsetName != "" {
|
|
||||||
var err error
|
|
||||||
ipset, err = m.addIpToSet(ipsetName, ip)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newRules = append(newRules, ioRule)
|
|
||||||
return newRules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
|
||||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.nftSet == nil {
|
|
||||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
|
||||||
}
|
|
||||||
if r.mangleRule != nil {
|
|
||||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(m.rules, r.ID())
|
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
|
||||||
if !ok {
|
|
||||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
|
||||||
}
|
|
||||||
if r.mangleRule != nil {
|
|
||||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(m.rules, r.ID())
|
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
|
||||||
}
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if after delete, set still contains other IPs,
|
|
||||||
// no need to delete firewall rule and we should exit here
|
|
||||||
if len(ips) > 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
|
||||||
}
|
|
||||||
if r.mangleRule != nil {
|
|
||||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
|
||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(m.rules, r.ID())
|
|
||||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
|
||||||
|
|
||||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we delete last IP from the set, that means we need to delete
|
|
||||||
// set itself and associated firewall rule too
|
|
||||||
m.rConn.FlushSet(r.nftSet)
|
|
||||||
m.rConn.DelSet(r.nftSet)
|
|
||||||
m.ipsetStore.deleteIpset(r.nftSet.Name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
|
||||||
func (m *AclManager) createDefaultAllowRules() error {
|
|
||||||
expIn := []expr.Any{
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainInputRules,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expIn,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
|
||||||
//
|
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
|
||||||
func (m *AclManager) Flush() error {
|
|
||||||
if err := m.flushWithBackoff(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addIOFiltering(
|
|
||||||
ip net.IP,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipset *nftables.Set,
|
|
||||||
) (*Rule, error) {
|
|
||||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
|
||||||
return &Rule{
|
|
||||||
nftRule: r.nftRule,
|
|
||||||
mangleRule: r.mangleRule,
|
|
||||||
nftSet: r.nftSet,
|
|
||||||
ruleID: r.ruleID,
|
|
||||||
ip: ip,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var expressions []expr.Any
|
|
||||||
|
|
||||||
if proto != firewall.ProtocolALL {
|
|
||||||
expressions = append(expressions, &expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: m.af.protoOffset,
|
|
||||||
Len: uint32(1),
|
|
||||||
})
|
|
||||||
|
|
||||||
protoData, err := m.af.protoNum(proto)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions, &expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: []byte{protoData},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIP := ipToBytes(ip, m.af)
|
|
||||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
|
||||||
// in that case not add IP match expression into the rule definition
|
|
||||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: m.af.srcAddrOffset,
|
|
||||||
Len: m.af.addrLen,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
// add individual IP for match if no ipset defined
|
|
||||||
if ipset == nil {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: rawIP,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ipset.Name,
|
|
||||||
SetID: ipset.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions, applyPort(sPort, true)...)
|
|
||||||
expressions = append(expressions, applyPort(dPort, false)...)
|
|
||||||
|
|
||||||
mainExpressions := slices.Clone(expressions)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case firewall.ActionAccept:
|
|
||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
|
||||||
case firewall.ActionDrop:
|
|
||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := []byte(ruleId)
|
|
||||||
|
|
||||||
chain := m.chainInputRules
|
|
||||||
rule := &nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: mainExpressions,
|
|
||||||
UserData: userData,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
|
||||||
var nftRule *nftables.Rule
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
nftRule = m.rConn.InsertRule(rule)
|
|
||||||
} else {
|
|
||||||
nftRule = m.rConn.AddRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleStruct := &Rule{
|
|
||||||
nftRule: nftRule,
|
|
||||||
// best effort mangle rule
|
|
||||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
|
||||||
nftSet: ipset,
|
|
||||||
ruleID: ruleId,
|
|
||||||
ip: ip,
|
|
||||||
}
|
|
||||||
m.rules[ruleId] = ruleStruct
|
|
||||||
if ipset != nil {
|
|
||||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ruleStruct, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
|
||||||
if m.chainPrerouting == nil {
|
|
||||||
log.Warn("prerouting chain is not created")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
preroutingExprs := slices.Clone(expressions)
|
|
||||||
|
|
||||||
// interface
|
|
||||||
preroutingExprs = append([]expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}, preroutingExprs...)
|
|
||||||
|
|
||||||
// local destination and mark
|
|
||||||
preroutingExprs = append(preroutingExprs,
|
|
||||||
&expr.Fib{
|
|
||||||
Register: 1,
|
|
||||||
ResultADDRTYPE: true,
|
|
||||||
FlagDADDR: true,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
|
||||||
},
|
|
||||||
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainPrerouting,
|
|
||||||
Exprs: preroutingExprs,
|
|
||||||
UserData: userData,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nfRule
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) createDefaultChains() (err error) {
|
|
||||||
// chainNameInputRules
|
|
||||||
chain := m.createChain(chainNameInputRules)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
m.chainInputRules = chain
|
|
||||||
|
|
||||||
// netbird-acl-input-filter
|
|
||||||
// type filter hook input priority filter; policy accept;
|
|
||||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
|
||||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
|
||||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// netbird-acl-forward-filter
|
|
||||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
|
||||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
|
||||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
|
||||||
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
|
||||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
|
||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
|
||||||
// netbird peer IP.
|
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
|
||||||
// Chain is created by route manager
|
|
||||||
// TODO: move creation to a common place
|
|
||||||
m.chainPrerouting = &nftables.Chain{
|
|
||||||
Name: chainNameManglePrerouting,
|
|
||||||
Table: m.workTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityMangle,
|
|
||||||
}
|
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|
||||||
m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: chainFwFilter,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: m.routingFwChainName,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: chainFwFilter,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: m.workTable,
|
|
||||||
}
|
|
||||||
|
|
||||||
chain = m.rConn.AddChain(chain)
|
|
||||||
|
|
||||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
|
||||||
|
|
||||||
return chain
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
|
||||||
polAccept := nftables.ChainPolicyAccept
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: m.workTable,
|
|
||||||
Hooknum: hookNum,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Policy: &polAccept,
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.rConn.AddChain(chain)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: to,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
|
||||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
|
||||||
rawIP := ipToBytes(ip, m.af)
|
|
||||||
if err != nil {
|
|
||||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("get set name: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ipsetStore.newIpset(ipset.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
|
|
||||||
return ipset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
|
||||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ipsetStore.AddIpToSet(ipset.Name, ip)
|
|
||||||
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createSet in given table by name
|
|
||||||
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
|
|
||||||
ipset := &nftables.Set{
|
|
||||||
Name: name,
|
|
||||||
Table: table,
|
|
||||||
Dynamic: true,
|
|
||||||
KeyType: m.af.setKeyType,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
|
||||||
return nil, fmt.Errorf("create set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush created set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) flushWithBackoff() (err error) {
|
|
||||||
backoff := 4
|
|
||||||
backoffTime := 1000 * time.Millisecond
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to flush nftables: %v", err)
|
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error("failed to flush nftables, retrying...")
|
|
||||||
if i == backoff-1 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
time.Sleep(backoffTime)
|
|
||||||
backoffTime *= 2
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
|
||||||
if m.workTable == nil || chain == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
list, err := m.rConn.GetRules(m.workTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range list {
|
|
||||||
if len(rule.UserData) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
split := bytes.Split(rule.UserData, []byte(" "))
|
|
||||||
r, ok := m.rules[string(split[0])]
|
|
||||||
if ok {
|
|
||||||
if mangle {
|
|
||||||
*r.mangleRule = *rule
|
|
||||||
} else {
|
|
||||||
*r.nftRule = *rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
|
||||||
rulesetID := ":" + string(proto) + ":"
|
|
||||||
if sPort != nil {
|
|
||||||
rulesetID += sPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
if dPort != nil {
|
|
||||||
rulesetID += dPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
rulesetID += strconv.Itoa(int(action))
|
|
||||||
if ipset == nil {
|
|
||||||
return "ip:" + ip.String() + rulesetID
|
|
||||||
}
|
|
||||||
return "set:" + ipset.Name + rulesetID
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
copy(b, n+"\x00")
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
|
||||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
|
||||||
if af.addrLen == 4 {
|
|
||||||
return ip.To4()
|
|
||||||
}
|
|
||||||
return ip.To16()
|
|
||||||
}
|
|
||||||
|
|
||||||
880
client/firewall/nftables/chains_linux.go
Normal file
880
client/firewall/nftables/chains_linux.go
Normal file
@@ -0,0 +1,880 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) createContainers() error {
|
||||||
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingFw,
|
||||||
|
Table: r.workTable,
|
||||||
|
})
|
||||||
|
|
||||||
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingNat,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: &prio,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingRdr,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameManglePostrouting,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameMangleForward,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
|
||||||
|
r.addPostroutingRules()
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
|
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kernel routing opens both INPUT and FORWARD.
|
||||||
|
if err := r.openInterface(true); err != nil {
|
||||||
|
log.Errorf("failed to open interface in foreign chains: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *family) setupDataPlaneMark() error {
|
||||||
|
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||||
|
return errors.New("no mangle chains found")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctNew := getCtNewExprs()
|
||||||
|
preExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
preExprs = append(preExprs, ctNew...)
|
||||||
|
preExprs = append(preExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
preNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
|
Exprs: preExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(preNftRule)
|
||||||
|
|
||||||
|
postExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postExprs = append(postExprs, ctNew...)
|
||||||
|
postExprs = append(postExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
postNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePostrouting],
|
||||||
|
Exprs: postExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(postNftRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openInterface adds passthrough accept rules for the NetBird interface to the
|
||||||
|
// kernel's filter table and external chains so they don't drop our traffic.
|
||||||
|
// includeForward also opens the FORWARD chains (kernel routing); when false only
|
||||||
|
// INPUT is opened, which is all the userspace router needs since it never
|
||||||
|
// forwards in the kernel.
|
||||||
|
func (r *family) openInterface(includeForward bool) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.acceptFilterTableRules(includeForward); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptExternalChainsRules(includeForward); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) acceptFilterTableRules(includeForward bool) error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := "iptables"
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Debugf("Used %s to add accept input/forward rules", fw)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Try iptables first and fallback to nftables if iptables is not available.
|
||||||
|
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||||
|
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
|
fw = "nftables"
|
||||||
|
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.acceptFilterRulesIptables(ipt, includeForward); err != nil {
|
||||||
|
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||||
|
fw = "nftables"
|
||||||
|
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables, includeForward bool) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if includeForward {
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||||
|
} else {
|
||||||
|
log.Debugf("added iptables forward rule: %v", rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputRule := r.getAcceptInputRule()
|
||||||
|
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||||
|
} else {
|
||||||
|
log.Debugf("added iptables input rule: %v", inputRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) getAcceptForwardRules() [][]string {
|
||||||
|
intf := r.wgIface.Name()
|
||||||
|
return [][]string{
|
||||||
|
{"-i", intf, "-j", "ACCEPT"},
|
||||||
|
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) getAcceptInputRule() []string {
|
||||||
|
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||||
|
// This is used when iptables is not available.
|
||||||
|
func (r *family) acceptFilterRulesNftables(table *nftables.Table, includeForward bool) error {
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
if includeForward {
|
||||||
|
forwardChain := &nftables.Chain{
|
||||||
|
Name: chainNameForward,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertForwardAcceptRules(forwardChain, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
inputChain := &nftables.Chain{
|
||||||
|
Name: chainNameInput,
|
||||||
|
Table: table,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookInput,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
}
|
||||||
|
r.insertInputAcceptRule(inputChain, intf)
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||||
|
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||||
|
func (r *family) acceptExternalChainsRules(includeForward bool) error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
for _, chain := range chains {
|
||||||
|
r.applyExternalChainAccept(chain, intf, includeForward)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush external chain rules: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte, includeForward bool) {
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||||
|
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||||
|
|
||||||
|
switch *chain.Hooknum {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
if includeForward {
|
||||||
|
r.insertForwardAcceptRules(chain, intf)
|
||||||
|
}
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
r.insertInputAcceptRule(chain, intf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||||
|
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.insertForwardIifRule(chain, intf, existing)
|
||||||
|
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||||
|
if existing[userDataAcceptForwardRuleIif] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
|
},
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||||
|
if existing[userDataAcceptForwardRuleOif] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||||
|
}
|
||||||
|
r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: append(exprs, getEstablishedExprs(2)...),
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||||
|
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing[userDataAcceptInputRule] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
|
},
|
||||||
|
UserData: []byte(userDataAcceptInputRule),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
||||||
|
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
||||||
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list rules: %w", err)
|
||||||
|
}
|
||||||
|
present := map[string]bool{}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
present[string(rule.UserData)] = true
|
||||||
|
}
|
||||||
|
return present, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
||||||
|
switch string(userData) {
|
||||||
|
case userDataAcceptForwardRuleIif,
|
||||||
|
userDataAcceptForwardRuleOif,
|
||||||
|
userDataAcceptInputRule:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeAcceptFilterRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeFilterTableRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeExternalChainsRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeFilterTableRules() error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||||
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||||
|
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||||
|
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||||
|
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name != table.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||||
|
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||||
|
// ensuring cleanup works even after a crash or if chains changed.
|
||||||
|
func (r *family) removeExternalChainsRules() error {
|
||||||
|
chains := r.findExternalChains()
|
||||||
|
if len(chains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, chain := range chains {
|
||||||
|
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove rules from external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||||
|
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||||
|
func (r *family) findExternalChains() []*nftables.Chain {
|
||||||
|
var chains []*nftables.Chain
|
||||||
|
|
||||||
|
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||||
|
|
||||||
|
for _, family := range families {
|
||||||
|
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("list chains for family %d: %v", family, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range allChains {
|
||||||
|
if r.isExternalChain(chain) {
|
||||||
|
chains = append(chains, chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return chains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) isExternalChain(chain *nftables.Chain) bool {
|
||||||
|
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||||
|
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||||
|
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||||
|
if chain.Table.Name == firewalldTableName {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||||
|
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Type != nftables.ChainTypeFilter {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain.Hooknum == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesTable(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputRule := r.getAcceptInputRule()
|
||||||
|
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (r *family) Flush() error {
|
||||||
|
if err := r.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// queuePreroutingRule builds the prerouting mangle rule that marks
|
||||||
|
// redirected traffic and queues it on the connection without flushing,
|
||||||
|
// so the caller can commit it in the same transaction as the rule it
|
||||||
|
// pairs with. Returns nil when the prerouting chain is absent, in which
|
||||||
|
// case nothing is queued.
|
||||||
|
func (r *family) queuePreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||||
|
if r.chainPrerouting == nil {
|
||||||
|
log.Warn("prerouting chain is not created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingExprs := slices.Clone(expressions)
|
||||||
|
|
||||||
|
// interface
|
||||||
|
preroutingExprs = append([]expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}, preroutingExprs...)
|
||||||
|
|
||||||
|
// local destination and mark
|
||||||
|
preroutingExprs = append(preroutingExprs,
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chainPrerouting,
|
||||||
|
Exprs: preroutingExprs,
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createDefaultChains() (err error) {
|
||||||
|
// chainNameInputRules
|
||||||
|
chain := r.createChain(chainNameInputRules)
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
r.chainInputRules = chain
|
||||||
|
|
||||||
|
// netbird-acl-input-filter
|
||||||
|
// type filter hook input priority filter; policy accept;
|
||||||
|
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
|
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||||
|
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// netbird-acl-forward-filter
|
||||||
|
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
|
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||||
|
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||||
|
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||||
|
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||||
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
|
// netbird peer IP.
|
||||||
|
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
|
r.chainPrerouting = r.chains[chainNameManglePrerouting]
|
||||||
|
|
||||||
|
r.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
|
r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: chainFwFilter,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: r.routingFwChainName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: chainFwFilter,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createChain(name string) *nftables.Chain {
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: name,
|
||||||
|
Table: r.workTable,
|
||||||
|
}
|
||||||
|
|
||||||
|
chain = r.conn.AddChain(chain)
|
||||||
|
|
||||||
|
insertReturnTrafficRule(r.conn, r.workTable, chain)
|
||||||
|
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||||
|
polAccept := nftables.ChainPolicyAccept
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: name,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: hookNum,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Policy: &polAccept,
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.AddChain(chain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
|
}
|
||||||
|
_ = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: to,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to flush nftables: %v", err)
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime *= 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||||
|
if r.workTable == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := r.conn.GetRules(r.workTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if mangle {
|
||||||
|
if pr.mangleRule != nil {
|
||||||
|
*pr.mangleRule = *rule
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*pr.nftRule = *rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
565
client/firewall/nftables/dnat_linux.go
Normal file
565
client/firewall/nftables/dnat_linux.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/google/nftables/xt"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request forwarding once the rule is about to be installed, releasing
|
||||||
|
// it if a later step fails so the refcount tracks the real rules.
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
|
||||||
|
r.releaseForwarding()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addDnatMasq(rule, protoNum, ruleID); err != nil {
|
||||||
|
r.releaseForwarding()
|
||||||
|
delete(r.rules, ruleID+dnatSuffix)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||||
|
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||||
|
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||||
|
// TODO: find chains with drop policies and add rules there
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
r.releaseForwarding()
|
||||||
|
return nil, fmt.Errorf("flush rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
||||||
|
dnatExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
portExprs, err := r.applyPort(&rule.DestinationPort, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination port: %w", err)
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, portExprs...)
|
||||||
|
|
||||||
|
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||||
|
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||||
|
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||||
|
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||||
|
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||||
|
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(r.af.tableFamily),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: regProtoMin,
|
||||||
|
RegProtoMax: regProtoMax,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingRdr],
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleID + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
switch {
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
return r.handlePortRange(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
return r.handleAddressOnly(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
return r.handleSinglePort(rule)
|
||||||
|
default:
|
||||||
|
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 3,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 0, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Target{
|
||||||
|
Name: "DNAT",
|
||||||
|
Rev: 2,
|
||||||
|
Info: &xt.NatRange2{
|
||||||
|
NatRange: xt.NatRange{
|
||||||
|
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||||
|
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MinPort: rule.TranslatedPort.Values[0],
|
||||||
|
MaxPort: rule.TranslatedPort.Values[1],
|
||||||
|
},
|
||||||
|
BasePort: rule.DestinationPort.Values[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
natTable := &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: r.af.tableFamily,
|
||||||
|
}
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: natTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: natTable,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
},
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleID + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
||||||
|
portExprs, err := r.applyPort(&rule.TranslatedPort, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply translated port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
masqExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: r.af.dstAddrOffset,
|
||||||
|
Len: r.af.addrLen,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
masqExprs = append(masqExprs, portExprs...)
|
||||||
|
masqExprs = append(masqExprs, &expr.Masq{})
|
||||||
|
|
||||||
|
masqRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: masqExprs,
|
||||||
|
UserData: []byte(ruleID + snatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(masqRule)
|
||||||
|
r.rules[ruleID+snatSuffix] = masqRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
var found bool
|
||||||
|
|
||||||
|
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||||
|
found = true
|
||||||
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
|
||||||
|
delete(r.rules, ruleID+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||||
|
found = true
|
||||||
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
|
||||||
|
delete(r.rules, ruleID+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsFlush {
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr != nil {
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(r.rules, ruleID+dnatSuffix)
|
||||||
|
delete(r.rules, ruleID+snatSuffix)
|
||||||
|
|
||||||
|
// Release once, only if the rule was present and removed.
|
||||||
|
if found {
|
||||||
|
r.releaseForwarding()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// releaseForwarding drops one IP forwarding reference, logging any error.
|
||||||
|
func (r *family) releaseForwarding() {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("release IP forwarding: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := r.af.protoNum(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 2,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 3,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 3,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bits := 32
|
||||||
|
if localAddr.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||||
|
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: localAddr.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||||
|
},
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(r.af.tableFamily),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: 2,
|
||||||
|
RegProtoMax: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingRdr],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleID),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
|
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||||
|
func (r *family) ensureNATOutputChain() error {
|
||||||
|
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameNATOutput,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookOutput,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
delete(r.chains, chainNameNATOutput)
|
||||||
|
return fmt.Errorf("create NAT output chain: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
|
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.ensureNATOutputChain(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := r.af.protoNum(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 2,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bits := 32
|
||||||
|
if localAddr.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||||
|
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: localAddr.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||||
|
},
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(r.af.tableFamily),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameNATOutput],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleID),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
|
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||||
|
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
249
client/firewall/nftables/family_linux.go
Normal file
249
client/firewall/nftables/family_linux.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
|
tableRaw = "raw"
|
||||||
|
tableSecurity = "security"
|
||||||
|
|
||||||
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
|
chainNameNATOutput = "netbird-nat-output"
|
||||||
|
chainNameForward = "FORWARD"
|
||||||
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
|
// Peer ACL chain names.
|
||||||
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
|
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||||
|
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||||
|
|
||||||
|
flushError = "flush: %w"
|
||||||
|
|
||||||
|
firewalldTableName = "firewalld"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
userDataAcceptInputRule = "inputaccept"
|
||||||
|
|
||||||
|
dnatSuffix firewall.RuleID = "_dnat"
|
||||||
|
snatSuffix firewall.RuleID = "_snat"
|
||||||
|
|
||||||
|
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv4TCPHeaderSize = 40
|
||||||
|
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||||
|
ipv6TCPHeaderSize = 60
|
||||||
|
|
||||||
|
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||||
|
maxPrefixesSet = 1500
|
||||||
|
refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type setInput struct {
|
||||||
|
set firewall.Set
|
||||||
|
prefixes []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// family holds the per-address-family nftables state. One instance
|
||||||
|
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||||
|
// single family; the top-level Manager owns one for v4 and another
|
||||||
|
// for v6. The name predates the peer-ACL absorption; it's effectively
|
||||||
|
// the per-family backend now.
|
||||||
|
type family struct {
|
||||||
|
conn *nftables.Conn
|
||||||
|
workTable *nftables.Table
|
||||||
|
filterTable *nftables.Table
|
||||||
|
chains map[string]*nftables.Chain
|
||||||
|
|
||||||
|
// filters holds peer + route filter rules keyed by content hash.
|
||||||
|
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||||
|
filters map[firewall.RuleID]*Rule
|
||||||
|
|
||||||
|
// rules holds NAT, DNAT, and external accept rules (auxiliary
|
||||||
|
// plumbing that isn't a filter rule).
|
||||||
|
rules map[firewall.RuleID]*nftables.Rule
|
||||||
|
|
||||||
|
// Peer ACL chain handles.
|
||||||
|
chainInputRules *nftables.Chain
|
||||||
|
chainPrerouting *nftables.Chain
|
||||||
|
routingFwChainName string
|
||||||
|
|
||||||
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
|
af addrFamily
|
||||||
|
wgIface iFaceMapper
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
|
legacyManagement bool
|
||||||
|
mtu uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) *family {
|
||||||
|
r := &family{
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
workTable: workTable,
|
||||||
|
chains: make(map[string]*nftables.Chain),
|
||||||
|
filters: make(map[firewall.RuleID]*Rule),
|
||||||
|
rules: make(map[firewall.RuleID]*nftables.Rule),
|
||||||
|
routingFwChainName: chainNameRoutingFw,
|
||||||
|
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||||
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
|
mtu: mtu,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ipsetCounter = refcounter.New(
|
||||||
|
r.createIpSet,
|
||||||
|
r.deleteIpSet,
|
||||||
|
)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r.filterTable, err = r.loadFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("ip filter table not found: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) init(workTable *nftables.Table) error {
|
||||||
|
r.workTable = workTable
|
||||||
|
|
||||||
|
if err := r.removeAcceptFilterRules(); err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.createContainers(); err != nil {
|
||||||
|
return fmt.Errorf("create containers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.createDefaultChains(); err != nil {
|
||||||
|
return fmt.Errorf("create default acl chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset cleans existing nftables filter table rules from the system
|
||||||
|
func (r *family) Reset() error {
|
||||||
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeAcceptFilterRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) loadFilterTable() (*nftables.Table, error) {
|
||||||
|
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list tables: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
if table.Name == "filter" {
|
||||||
|
return table, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errFilterTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func hookName(hook *nftables.ChainHook) string {
|
||||||
|
if hook == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
switch *hook {
|
||||||
|
case *nftables.ChainHookForward:
|
||||||
|
return chainNameForward
|
||||||
|
case *nftables.ChainHookInput:
|
||||||
|
return chainNameInput
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("hook(%d)", *hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func familyName(family nftables.TableFamily) string {
|
||||||
|
switch family {
|
||||||
|
case nftables.TableFamilyIPv4:
|
||||||
|
return "ip"
|
||||||
|
case nftables.TableFamilyIPv6:
|
||||||
|
return "ip6"
|
||||||
|
case nftables.TableFamilyINet:
|
||||||
|
return "inet"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("family(%d)", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) iptablesProto() iptables.Protocol {
|
||||||
|
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||||
|
return iptables.ProtocolIPv6
|
||||||
|
}
|
||||||
|
return iptables.ProtocolIPv4
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) refreshRulesMap() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
newRules := make(map[firewall.RuleID]*nftables.Rule)
|
||||||
|
for _, chain := range r.chains {
|
||||||
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||||
|
// preserve existing entries for this chain since we can't verify their state
|
||||||
|
for k, v := range r.rules {
|
||||||
|
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||||
|
newRules[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 {
|
||||||
|
newRules[firewall.RuleID(rule.UserData)] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.rules = newRules
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
512
client/firewall/nftables/filter_linux.go
Normal file
512
client/firewall/nftables/filter_linux.go
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddFilterRule installs one nftables packet-filter rule. With
|
||||||
|
// destination empty the rule goes to the peer ACL input chain plus a
|
||||||
|
// paired prerouting mangle rule for the redirect mark. With
|
||||||
|
// destination set (prefix or named set) it goes to the route ACL
|
||||||
|
// forward chain. Multi-source rules collapse to one nftables rule
|
||||||
|
// backed by the shared refcounted hash:net set.
|
||||||
|
func (r *family) AddFilterRule(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
isRoute := !destination.IsZero()
|
||||||
|
|
||||||
|
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||||
|
if existing, ok := r.filters[ruleID]; ok {
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
if isRoute {
|
||||||
|
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
|
||||||
|
} else {
|
||||||
|
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(srcExprs)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mainExprs := slices.Clone(exprs)
|
||||||
|
verdict := expr.VerdictAccept
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
verdict = expr.VerdictDrop
|
||||||
|
}
|
||||||
|
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
|
||||||
|
|
||||||
|
chain := r.chainInputRules
|
||||||
|
if isRoute {
|
||||||
|
chain = r.chains[chainNameRoutingFw]
|
||||||
|
}
|
||||||
|
|
||||||
|
userData := []byte(ruleID)
|
||||||
|
|
||||||
|
// Build the paired prerouting mangle rule before flushing so both
|
||||||
|
// rules commit in one transaction. An anonymous port set binds to
|
||||||
|
// exactly one rule, so the mangle rule needs its own expression list
|
||||||
|
// with fresh sets, not a clone of the main rule's. Guard on the
|
||||||
|
// prerouting chain first: building the expressions queues the port
|
||||||
|
// set, so skipping the build when there is no chain to bind it to
|
||||||
|
// keeps an unbound set out of the connection batch.
|
||||||
|
var mangleRule *nftables.Rule
|
||||||
|
if !isRoute && r.chainPrerouting != nil {
|
||||||
|
mangleExprs, err := r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(exprs)
|
||||||
|
return nil, fmt.Errorf("build mangle rule: %w", err)
|
||||||
|
}
|
||||||
|
mangleRule = r.queuePreroutingRule(mangleExprs, userData)
|
||||||
|
}
|
||||||
|
|
||||||
|
nftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: mainExprs,
|
||||||
|
UserData: userData,
|
||||||
|
}
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
nftRule = r.conn.InsertRule(nftRule)
|
||||||
|
} else {
|
||||||
|
nftRule = r.conn.AddRule(nftRule)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
r.dropNetworkMatch(exprs)
|
||||||
|
return nil, fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &Rule{
|
||||||
|
nftRule: nftRule,
|
||||||
|
mangleRule: mangleRule,
|
||||||
|
sources: sources,
|
||||||
|
id: ruleID,
|
||||||
|
}
|
||||||
|
r.filters[ruleID] = rule
|
||||||
|
|
||||||
|
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
|
||||||
|
sources, destination, proto, sPort, dPort, action)
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
|
||||||
|
// IP-header protocol byte read via Payload, then source, then ports
|
||||||
|
// (no counter), matching the historical peer shape so per-rule kernel
|
||||||
|
// state is identical to pre-unification.
|
||||||
|
func (r *family) buildPeerFilterExprs(
|
||||||
|
srcExprs []expr.Any,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
if proto != firewall.ProtocolALL {
|
||||||
|
protoNum, err := r.af.protoNum(proto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: r.af.protoOffset,
|
||||||
|
Len: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, srcExprs...)
|
||||||
|
|
||||||
|
portExprs, err := r.applyPorts(sPort, dPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
exprs = append(exprs, portExprs...)
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
|
||||||
|
// source, then destination, then optional proto/ports, then a counter.
|
||||||
|
func (r *family) buildRouteFilterExprs(
|
||||||
|
srcExprs []expr.Any,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
exprs := append([]expr.Any{}, srcExprs...)
|
||||||
|
|
||||||
|
destExprs, err := r.applyNetwork(destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, destExprs...)
|
||||||
|
|
||||||
|
if proto != firewall.ProtocolALL {
|
||||||
|
protoNum, err := r.af.protoNum(proto)
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(destExprs)
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||||
|
)
|
||||||
|
|
||||||
|
portExprs, err := r.applyPorts(sPort, dPort)
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(destExprs)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
exprs = append(exprs, portExprs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, &expr.Counter{})
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) hasRule(id firewall.RuleID) bool {
|
||||||
|
_, ok := r.filters[id]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||||
|
_, ok := r.rules[id+dnatSuffix]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFilterRule removes a previously installed filter rule. Source
|
||||||
|
// set references are recovered from the stored rule's expressions via
|
||||||
|
// findSets and dropped from the shared refcounter.
|
||||||
|
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||||
|
ruleID := rule.ID()
|
||||||
|
pr, ok := r.filters[ruleID]
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("filter rule %s not found", ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A freshly added rule carries no handle until it is read back from
|
||||||
|
// the kernel, and Flush only refreshes the peer chains. Pull live
|
||||||
|
// handles for this rule's chain before deciding it is stale so route
|
||||||
|
// rules (which Flush never refreshes) can actually be deleted. A
|
||||||
|
// refresh failure aborts the delete without touching tracking state,
|
||||||
|
// so the caller can retry while the rule may still exist in the kernel.
|
||||||
|
if pr.nftRule.Handle == 0 {
|
||||||
|
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
|
||||||
|
return fmt.Errorf("refresh handles for chain %s: %w", pr.nftRule.Chain.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Refresh the mangle handle independently: the main rule's handle can
|
||||||
|
// be populated while the prerouting refresh during Flush failed, and
|
||||||
|
// gating the mangle refresh on the main handle would leak the mangle
|
||||||
|
// rule on delete.
|
||||||
|
if pr.mangleRule != nil && pr.mangleRule.Handle == 0 {
|
||||||
|
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||||
|
return fmt.Errorf("refresh mangle handles: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pr.nftRule.Handle == 0 {
|
||||||
|
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||||
|
delete(r.filters, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(pr.nftRule); err != nil {
|
||||||
|
log.Errorf("queue rule delete: %v", err)
|
||||||
|
}
|
||||||
|
if pr.mangleRule != nil {
|
||||||
|
if err := r.conn.DelRule(pr.mangleRule); err != nil {
|
||||||
|
log.Errorf("queue mangle rule delete: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||||
|
delete(r.filters, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
|
||||||
|
if r.ipsetCounter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sets := findSets(rule)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dropNetworkMatch undoes whatever the source/destination match
|
||||||
|
// reserved. Safe to call when the spec is empty or holds only inline
|
||||||
|
// matchers.
|
||||||
|
func (r *family) dropNetworkMatch(exprs []expr.Any) {
|
||||||
|
if r.ipsetCounter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, e := range exprs {
|
||||||
|
lookup, ok := e.(*expr.Lookup)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
|
||||||
|
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) applyNetwork(
|
||||||
|
network firewall.Network,
|
||||||
|
setPrefixes []netip.Prefix,
|
||||||
|
isSource bool,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
if network.IsSet() {
|
||||||
|
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||||
|
if err != nil {
|
||||||
|
side := "destination"
|
||||||
|
if isSource {
|
||||||
|
side = "source"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("%s set: %w", side, err)
|
||||||
|
}
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPort builds the transport-header port match. A single value
|
||||||
|
// compares directly, a range uses a range expression, and multiple
|
||||||
|
// values go through an anonymous constant set: consecutive cmp
|
||||||
|
// expressions AND together, so chained equality comparisons could
|
||||||
|
// never match more than one port. The set is queued on the
|
||||||
|
// connection and committed by the caller's flush together with the
|
||||||
|
// rule that binds it.
|
||||||
|
func (r *family) applyPort(port *firewall.Port, isSource bool) ([]expr.Any, error) {
|
||||||
|
if port == nil || len(port.Values) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dst port
|
||||||
|
offset := uint32(2)
|
||||||
|
if isSource {
|
||||||
|
// src port
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case port.IsRange && len(port.Values) == 2:
|
||||||
|
exprs = append(exprs, &expr.Range{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
|
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
|
})
|
||||||
|
case len(port.Values) == 1:
|
||||||
|
exprs = append(exprs, &expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
lookup, err := r.anonymousPortSet(port.Values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
exprs = append(exprs, lookup)
|
||||||
|
}
|
||||||
|
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// anonymousPortSet queues an anonymous constant set holding the given
|
||||||
|
// ports on the connection and returns a lookup against it. The set is
|
||||||
|
// committed by the caller's flush together with the rule that binds it.
|
||||||
|
func (r *family) anonymousPortSet(values []uint16) (*expr.Lookup, error) {
|
||||||
|
set := &nftables.Set{
|
||||||
|
Anonymous: true,
|
||||||
|
Constant: true,
|
||||||
|
Table: r.workTable,
|
||||||
|
KeyType: nftables.TypeInetService,
|
||||||
|
}
|
||||||
|
elements := make([]nftables.SetElement, 0, len(values))
|
||||||
|
for _, p := range values {
|
||||||
|
elements = append(elements, nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)})
|
||||||
|
}
|
||||||
|
if err := r.conn.AddSet(set, elements); err != nil {
|
||||||
|
return nil, fmt.Errorf("add anonymous port set: %w", err)
|
||||||
|
}
|
||||||
|
return &expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetID: set.ID,
|
||||||
|
SetName: set.Name,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPorts builds the source then destination port matches.
|
||||||
|
func (r *family) applyPorts(sPort, dPort *firewall.Port) ([]expr.Any, error) {
|
||||||
|
sPortExprs, err := r.applyPort(sPort, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dPortExprs, err := r.applyPort(dPort, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply destination port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(sPortExprs, dPortExprs...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixMatchExprs is the family-aware match sequence for a CIDR
|
||||||
|
// prefix. /0 returns nil; a host prefix (full bit length for the
|
||||||
|
// family) skips the bitwise step since the mask is all-ones. Shared
|
||||||
|
// between family and aclManager so both treat single prefixes
|
||||||
|
// identically.
|
||||||
|
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
|
offset := af.dstAddrOffset
|
||||||
|
if isSource {
|
||||||
|
offset = af.srcAddrOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
ones := prefix.Bits()
|
||||||
|
if ones == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: af.addrLen,
|
||||||
|
}
|
||||||
|
cmp := &expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: prefix.Masked().Addr().AsSlice(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if ones == af.totalBits {
|
||||||
|
return []expr.Any{payload, cmp}
|
||||||
|
}
|
||||||
|
|
||||||
|
mask := net.CIDRMask(ones, af.totalBits)
|
||||||
|
xor := make([]byte, af.addrLen)
|
||||||
|
return []expr.Any{
|
||||||
|
payload,
|
||||||
|
&expr.Bitwise{
|
||||||
|
DestRegister: 1,
|
||||||
|
SourceRegister: 1,
|
||||||
|
Len: af.addrLen,
|
||||||
|
Mask: mask,
|
||||||
|
Xor: xor,
|
||||||
|
},
|
||||||
|
cmp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCtNewExprs() []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||||
|
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||||
|
// single prefix inline, or an ipset for multiple sources.
|
||||||
|
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||||
|
switch {
|
||||||
|
case len(sources) == 0:
|
||||||
|
return firewall.Network{}
|
||||||
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
|
return firewall.Network{}
|
||||||
|
case len(sources) == 1:
|
||||||
|
return firewall.Network{Prefix: sources[0]}
|
||||||
|
default:
|
||||||
|
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ifname(n string) []byte {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
copy(b, n+"\x00")
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// findSets scans an nftables rule's expressions for expr.Lookup and
|
||||||
|
// returns the named sets in occurrence order. Used at delete time to
|
||||||
|
// drop ipsetCounter references; peer and route ACLs go through it.
|
||||||
|
func findSets(rule *nftables.Rule) []string {
|
||||||
|
var sets []string
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
|
sets = append(sets, lookup.SetName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
|
}
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestInterfaceAllowerInputOnly verifies the userspace-mode allower opens the
|
||||||
|
// interface on the INPUT hook of foreign chains only (not FORWARD, since the
|
||||||
|
// userspace router never forwards in the kernel), creates no netbird work
|
||||||
|
// table, and removes its rules on Close.
|
||||||
|
func TestInterfaceAllowerInputOnly(t *testing.T) {
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
t.Skip("root required")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.False(t, ipTableExists(t, getTableName()), "precondition: no stale netbird table")
|
||||||
|
|
||||||
|
conn := &nftables.Conn{}
|
||||||
|
extTable := conn.AddTable(&nftables.Table{Name: "nbtest_extchains", Family: nftables.TableFamilyINet})
|
||||||
|
inputChain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "ext_input", Table: extTable,
|
||||||
|
Hooknum: nftables.ChainHookInput, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
forwardChain := conn.AddChain(&nftables.Chain{
|
||||||
|
Name: "ext_forward", Table: extTable,
|
||||||
|
Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
require.NoError(t, conn.Flush(), "create external table and chains")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
c := &nftables.Conn{}
|
||||||
|
c.DelTable(extTable)
|
||||||
|
_ = c.Flush()
|
||||||
|
})
|
||||||
|
|
||||||
|
allower, err := NewInterfaceAllower(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "create allower")
|
||||||
|
require.NoError(t, allower.Apply(), "apply")
|
||||||
|
|
||||||
|
require.True(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
|
||||||
|
"external INPUT chain should get the accept rule")
|
||||||
|
require.Len(t, listRules(t, extTable, forwardChain), 0,
|
||||||
|
"external FORWARD chain must not be opened in userspace mode")
|
||||||
|
require.False(t, ipTableExists(t, getTableName()),
|
||||||
|
"allower must not create a netbird work table")
|
||||||
|
|
||||||
|
require.NoError(t, allower.Close(), "close")
|
||||||
|
require.False(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
|
||||||
|
"accept rule should be removed on close")
|
||||||
|
}
|
||||||
|
|
||||||
|
func listRules(t *testing.T, table *nftables.Table, chain *nftables.Chain) []*nftables.Rule {
|
||||||
|
t.Helper()
|
||||||
|
c := &nftables.Conn{}
|
||||||
|
rules, err := c.GetRules(table, chain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func chainHasUserData(t *testing.T, table *nftables.Table, chain *nftables.Chain, ud string) bool {
|
||||||
|
for _, r := range listRules(t, table, chain) {
|
||||||
|
if bytes.Equal(r.UserData, []byte(ud)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipTableExists(t *testing.T, name string) bool {
|
||||||
|
t.Helper()
|
||||||
|
c := &nftables.Conn{}
|
||||||
|
for _, fam := range []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyIPv6} {
|
||||||
|
tbls, err := c.ListTablesOfFamily(fam)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, tb := range tbls {
|
||||||
|
if tb.Name == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
107
client/firewall/nftables/interface_allower_linux.go
Normal file
107
client/firewall/nftables/interface_allower_linux.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InterfaceAllower opens the NetBird interface in the kernel's filter table and
|
||||||
|
// external chains and keeps them reconciled via a netlink monitor, so the host
|
||||||
|
// firewall doesn't drop traffic the NetBird firewall handles. It is used by the
|
||||||
|
// userspace firewall, where routing happens in the forwarder, so only INPUT is
|
||||||
|
// opened (the userspace router never forwards in the kernel).
|
||||||
|
//
|
||||||
|
// It owns its own families/connection and never creates a netbird work table.
|
||||||
|
// firewalld trust is handled by the caller, not here. Its operations are serial
|
||||||
|
// (Apply before the monitor starts; reconciles run on the single monitor
|
||||||
|
// goroutine; Close stops the monitor before removing), so it needs no locking.
|
||||||
|
//
|
||||||
|
// TODO: this opens nftables and the iptables-nft filter table (detected via
|
||||||
|
// nft), but not a legacy-iptables ruleset running in parallel with nftables.
|
||||||
|
// Such a host would keep its legacy filter chains closed for the interface.
|
||||||
|
type InterfaceAllower struct {
|
||||||
|
family4 *family
|
||||||
|
family6 *family
|
||||||
|
extMonitor *externalChainMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInterfaceAllower builds an allower for the given interface. It returns an
|
||||||
|
// error when nftables is unavailable (e.g. an iptables-legacy host), so the
|
||||||
|
// caller can fall back to firewalld trust.
|
||||||
|
func NewInterfaceAllower(wgIface iFaceMapper, mtu uint16) (*InterfaceAllower, error) {
|
||||||
|
tableName := getTableName()
|
||||||
|
|
||||||
|
family4 := newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}, wgIface, mtu)
|
||||||
|
|
||||||
|
// Probe nftables availability before committing to this backend.
|
||||||
|
if _, err := family4.conn.ListChainsOfTableFamily(nftables.TableFamilyINet); err != nil {
|
||||||
|
return nil, fmt.Errorf("nftables not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &InterfaceAllower{family4: family4}
|
||||||
|
|
||||||
|
if wgIface.Address().HasIPv6() {
|
||||||
|
a.family6 = newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}, wgIface, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.extMonitor = newExternalChainMonitor(a)
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply opens the interface (INPUT only) in the foreign filter chains and starts
|
||||||
|
// reconciling them on nftables changes.
|
||||||
|
func (a *InterfaceAllower) Apply() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, f := range a.families() {
|
||||||
|
// Remove any stale accepts first so a prior unclean exit (e.g. SIGKILL,
|
||||||
|
// where Close never ran) is recovered deterministically rather than
|
||||||
|
// accumulating duplicate rules on the iptables filter table.
|
||||||
|
if err := f.removeAcceptFilterRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("clean stale accept rules: %w", err))
|
||||||
|
}
|
||||||
|
if err := f.openInterface(false); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a.extMonitor.start()
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// families returns the configured address families (v4, and v6 when present).
|
||||||
|
func (a *InterfaceAllower) families() []*family {
|
||||||
|
families := []*family{a.family4}
|
||||||
|
if a.family6 != nil {
|
||||||
|
families = append(families, a.family6)
|
||||||
|
}
|
||||||
|
return families
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconcileExternalChains re-applies the INPUT accepts to external chains. It
|
||||||
|
// implements externalChainReconciler for the monitor.
|
||||||
|
func (a *InterfaceAllower) reconcileExternalChains() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, f := range a.families() {
|
||||||
|
if err := f.acceptExternalChainsRules(false); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the monitor and removes the accept rules.
|
||||||
|
func (a *InterfaceAllower) Close() error {
|
||||||
|
a.extMonitor.stop()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, f := range a.families() {
|
||||||
|
if err := f.removeAcceptFilterRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
210
client/firewall/nftables/ipset_linux.go
Normal file
210
client/firewall/nftables/ipset_linux.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||||
|
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||||
|
set: set,
|
||||||
|
prefixes: prefixes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.getIpSetExprs(ref, isSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||||
|
// overlapping prefixes will result in an error, so we need to merge them
|
||||||
|
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||||
|
|
||||||
|
nfset := &nftables.Set{
|
||||||
|
Name: setName,
|
||||||
|
Comment: input.set.Comment(),
|
||||||
|
Table: r.workTable,
|
||||||
|
// required for prefixes
|
||||||
|
Interval: true,
|
||||||
|
KeyType: r.af.setKeyType,
|
||||||
|
}
|
||||||
|
|
||||||
|
elements := r.convertPrefixesToSet(prefixes)
|
||||||
|
nElements := len(elements)
|
||||||
|
|
||||||
|
maxElements := maxPrefixesSet * 2
|
||||||
|
initialElements := elements[:min(maxElements, nElements)]
|
||||||
|
|
||||||
|
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||||
|
|
||||||
|
// The set is committed now. If a later batch fails, destroy it: the
|
||||||
|
// refcounter records nothing on a create-callback error, so it would
|
||||||
|
// otherwise leak, and a partial source set fails-open for deny rules.
|
||||||
|
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
|
||||||
|
if derr := r.deleteIpSet(setName, nfset); derr != nil {
|
||||||
|
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||||
|
return nfset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRemainingElements adds element batches beyond the initial one in
|
||||||
|
// maxElements-sized chunks, flushing each. Called after the set has been
|
||||||
|
// created with its first batch.
|
||||||
|
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
|
||||||
|
nElements := len(elements)
|
||||||
|
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||||
|
subEnd := min(subStart+maxElements, nElements)
|
||||||
|
subElement := elements[subStart:subEnd]
|
||||||
|
nSubPrefixes := len(subElement) / 2
|
||||||
|
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||||
|
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||||
|
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
|
var elements []nftables.SetElement
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||||
|
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||||
|
firstIP := prefix.Addr()
|
||||||
|
|
||||||
|
// For a /0 the last address is the broadcast and its Next() overflows
|
||||||
|
// to an invalid Addr with an empty key, so wrap to the zero address,
|
||||||
|
// which nftables reads as the open end of a full-range interval.
|
||||||
|
var lastKey []byte
|
||||||
|
if prefix.Bits() == 0 {
|
||||||
|
lastKey = make([]byte, r.af.addrLen)
|
||||||
|
} else {
|
||||||
|
lastKey = calculateLastIP(prefix).Next().AsSlice()
|
||||||
|
}
|
||||||
|
|
||||||
|
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||||
|
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||||
|
elements = append(elements,
|
||||||
|
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||||
|
nftables.SetElement{Key: lastKey, IntervalEnd: true},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return elements
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
|
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||||
|
masked := prefix.Masked()
|
||||||
|
if masked.Addr().Is4() {
|
||||||
|
hostMask := ^uint32(0) >> masked.Bits()
|
||||||
|
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||||
|
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6: set host bits to all 1s
|
||||||
|
b := masked.Addr().As16()
|
||||||
|
bits := masked.Bits()
|
||||||
|
for i := bits; i < 128; i++ {
|
||||||
|
b[i/8] |= 1 << (7 - i%8)
|
||||||
|
}
|
||||||
|
return netip.AddrFrom16(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert netip.Addr to uint32.
|
||||||
|
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||||
|
b := addr.As4()
|
||||||
|
return binary.BigEndian.Uint32(b[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||||
|
func uint32ToBytes(ip uint32) [4]byte {
|
||||||
|
var b [4]byte
|
||||||
|
binary.BigEndian.PutUint32(b[:], ip)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||||
|
r.conn.DelSet(nfset)
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Deleted unused ipset %s", setName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
|
||||||
|
// interval set reject the batch, so merge them as createIpSet does.
|
||||||
|
prefixes = firewall.MergeIPRanges(prefixes)
|
||||||
|
elements := r.convertPrefixesToSet(prefixes)
|
||||||
|
|
||||||
|
// Add in batches sized like createIpSet so a large update does not
|
||||||
|
// exceed the netlink message size limit.
|
||||||
|
maxElements := maxPrefixesSet * 2
|
||||||
|
for start := 0; start < len(elements); start += maxElements {
|
||||||
|
end := min(start+maxElements, len(elements))
|
||||||
|
if err := r.conn.SetAddElements(nfset, elements[start:end]); err != nil {
|
||||||
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated set %s with %d prefixes", set.HashedName(), len(prefixes))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
|
// dst offset by default
|
||||||
|
offset := r.af.dstAddrOffset
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = r.af.srcAddrOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: r.af.addrLen,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ref.Out.Name,
|
||||||
|
SetID: ref.Out.ID,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
36
client/firewall/nftables/ipset_linux_test.go
Normal file
36
client/firewall/nftables/ipset_linux_test.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConvertPrefixesToSetWildcard verifies that a /0 prefix produces a
|
||||||
|
// usable interval. The last address of a /0 is the broadcast, whose Next()
|
||||||
|
// overflows to an invalid Addr with an empty key; the IntervalEnd must wrap
|
||||||
|
// to the zero address instead so nftables sees a full-range interval.
|
||||||
|
func TestConvertPrefixesToSetWildcard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
af addrFamily
|
||||||
|
prefix string
|
||||||
|
}{
|
||||||
|
{"IPv4 /0", afIPv4, "0.0.0.0/0"},
|
||||||
|
{"IPv6 /0", afIPv6, "::/0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := &family{af: tt.af}
|
||||||
|
elements := r.convertPrefixesToSet([]netip.Prefix{netip.MustParsePrefix(tt.prefix)})
|
||||||
|
|
||||||
|
require.Len(t, elements, 2, "expected start and interval-end element")
|
||||||
|
assert.False(t, elements[0].IntervalEnd, "first element is the interval start")
|
||||||
|
assert.True(t, elements[1].IntervalEnd, "second element is the interval end")
|
||||||
|
assert.Len(t, elements[1].Key, int(tt.af.addrLen), "interval-end key must be a zero address, not empty")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ipsetStore struct {
|
|
||||||
ipsetReference map[string]int
|
|
||||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
|
||||||
return &ipsetStore{
|
|
||||||
ipsetReference: make(map[string]int),
|
|
||||||
ipsets: make(map[string]map[string]struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
|
||||||
r, ok := s.ipsets[ipsetName]
|
|
||||||
return r, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
|
||||||
s.ipsetReference[ipsetName] = 0
|
|
||||||
ipList := make(map[string]struct{})
|
|
||||||
s.ipsets[ipsetName] = ipList
|
|
||||||
return ipList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
|
||||||
delete(s.ipsetReference, ipsetName)
|
|
||||||
delete(s.ipsets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(ipList, ip.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipList[ip.String()] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
|
||||||
ipList, ok := s.ipsets[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, ok = ipList[ip.String()]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
|
||||||
s.ipsetReference[ipsetName]++
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
|
||||||
r, ok := s.ipsetReference[ipsetName]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.ipsetReference[ipsetName]--
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
|
||||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.ipsetReference[ipsetName] == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -16,7 +15,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -45,18 +43,17 @@ type iFaceMapper interface {
|
|||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of nftables firewall. Per-family state (peer ACLs, route
|
||||||
|
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||||
|
// by family and provides the public firewall.Manager surface.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
router *router
|
family4 *family
|
||||||
aclManager *AclManager
|
// IPv6 counterpart, nil when no v6 overlay.
|
||||||
|
family6 *family
|
||||||
// IPv6 counterparts, nil when no v6 overlay
|
|
||||||
router6 *router
|
|
||||||
aclManager6 *AclManager
|
|
||||||
|
|
||||||
notrackOutputChain *nftables.Chain
|
notrackOutputChain *nftables.Chain
|
||||||
notrackPreroutingChain *nftables.Chain
|
notrackPreroutingChain *nftables.Chain
|
||||||
@@ -74,21 +71,10 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
tableName := getTableName()
|
tableName := getTableName()
|
||||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||||
|
|
||||||
var err error
|
m.family4 = newFamily(workTable, wgIface, mtu)
|
||||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create router: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
if wgIface.Address().HasIPv6() {
|
||||||
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
m.createIPv6Components(tableName, wgIface, mtu)
|
||||||
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.extMonitor = newExternalChainMonitor(m)
|
m.extMonitor = newExternalChainMonitor(m)
|
||||||
@@ -96,30 +82,19 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) {
|
||||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||||
|
|
||||||
var err error
|
m.family6 = newFamily(workTable6, wgIface, mtu)
|
||||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 router: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
m.router6.ipFwdState = m.router.ipFwdState
|
m.family6.ipFwdState = m.family4.ipFwdState
|
||||||
|
|
||||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||||
func (m *Manager) hasIPv6() bool {
|
func (m *Manager) hasIPv6() bool {
|
||||||
return m.router6 != nil
|
return m.family6 != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) initIPv6() error {
|
func (m *Manager) initIPv6() error {
|
||||||
@@ -128,12 +103,8 @@ func (m *Manager) initIPv6() error {
|
|||||||
return fmt.Errorf("create v6 work table: %w", err)
|
return fmt.Errorf("create v6 work table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router6.init(workTable6); err != nil {
|
if err := m.family6.init(workTable6); err != nil {
|
||||||
return fmt.Errorf("v6 router init: %w", err)
|
return fmt.Errorf("v6 family init: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.aclManager6.init(workTable6); err != nil {
|
|
||||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -156,19 +127,20 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
// reconcileExternalChains re-applies passthrough accept rules to external
|
// reconcileExternalChains re-applies passthrough accept rules to external
|
||||||
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
||||||
// tables or chains appear (e.g. after firewalld reloads).
|
// tables or chains appear (e.g. after firewalld reloads). Kernel routing opens
|
||||||
|
// both INPUT and FORWARD.
|
||||||
func (m *Manager) reconcileExternalChains() error {
|
func (m *Manager) reconcileExternalChains() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
if m.router != nil {
|
if m.family4 != nil {
|
||||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
if err := m.family4.acceptExternalChainsRules(true); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
if err := m.family6.acceptExternalChainsRules(true); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -187,12 +159,8 @@ func (m *Manager) initFirewall() (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := m.router.init(workTable); err != nil {
|
if err := m.family4.init(workTable); err != nil {
|
||||||
return fmt.Errorf("router init: %w", err)
|
return fmt.Errorf("family init: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.aclManager.init(workTable); err != nil {
|
|
||||||
return fmt.Errorf("acl manager init: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
@@ -220,7 +188,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.router.mtu,
|
MTU: m.family4.mtu,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -235,12 +203,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|||||||
|
|
||||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||||
func (m *Manager) rollbackInit() {
|
func (m *Manager) rollbackInit() {
|
||||||
if err := m.router.Reset(); err != nil {
|
if err := m.family4.Reset(); err != nil {
|
||||||
log.Warnf("rollback router: %v", err)
|
log.Warnf("rollback family: %v", err)
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.router6.Reset(); err != nil {
|
if err := m.family6.Reset(); err != nil {
|
||||||
log.Warnf("rollback v6 router: %v", err)
|
log.Warnf("rollback v6 family: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := m.cleanupNetbirdTables(); err != nil {
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
@@ -251,118 +219,82 @@ func (m *Manager) rollbackInit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddFilterRule installs a packet-filtering rule.
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// Destination semantics: zero Network → input chain (peer ACL);
|
||||||
// rule ID as comment for the rule
|
// set Network → forward chain (route ACL).
|
||||||
func (m *Manager) AddPeerFiltering(
|
//
|
||||||
id []byte,
|
// Sources are a single address family; the rule is dispatched to the
|
||||||
ip net.IP,
|
// matching per-family backend.
|
||||||
proto firewall.Protocol,
|
func (m *Manager) AddFilterRule(
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination firewall.Network,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort, dPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
if len(sources) == 0 {
|
||||||
|
return nil, firewall.ErrNoSources
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if isIPv6RouteRule(sources, destination) {
|
fam := m.family4
|
||||||
|
if isIPv6Rule(sources, destination) {
|
||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
fam = m.family6
|
||||||
}
|
}
|
||||||
|
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeleteFilterRule removes a filtering rule. The owning family is found
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
// by id in the in-memory filter maps, which are the only tracking for
|
||||||
|
// filter rules. family.DeleteFilterRule is idempotent when the id is
|
||||||
|
// absent.
|
||||||
|
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule, false)
|
||||||
return m.aclManager6.DeletePeerRule(rule)
|
|
||||||
}
|
|
||||||
return m.aclManager.DeletePeerRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIPv6Rule(rule firewall.Rule) bool {
|
|
||||||
r, ok := rule.(*Rule)
|
|
||||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
|
||||||
// For static routes, the destination prefix determines the family. For dynamic
|
|
||||||
// routes (DomainSet), the sources determine the family since management
|
|
||||||
// duplicates dynamic rules per family.
|
|
||||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
|
||||||
if destination.IsPrefix() {
|
|
||||||
return destination.Prefix.Addr().Is6()
|
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
|
||||||
// router; the cached maps are normally authoritative, so the kernel is only
|
|
||||||
// consulted when neither map knows about the rule.
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
id := rule.ID()
|
|
||||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return r.DeleteRouteRule(rule)
|
return fam.DeleteFilterRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// routerForRuleID picks the router holding the rule with the given id, using
|
// familyForRuleID picks the family holding the rule with the given id, using
|
||||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
// the supplied lookup. With refresh set, a miss in both cached maps reloads
|
||||||
// from the kernel once and re-checks before falling back to the v4 router.
|
// the NAT/DNAT rule maps from the kernel once and re-checks before falling
|
||||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
// back to the v4 family. Filter rules are tracked only in memory and have no
|
||||||
if has(m.router, id) {
|
// kernel-backed reload, so their callers pass refresh as false.
|
||||||
return m.router, nil
|
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool, refresh bool) (*family, error) {
|
||||||
}
|
if has(m.family4, id) {
|
||||||
if m.hasIPv6() && has(m.router6, id) {
|
return m.family4, nil
|
||||||
return m.router6, nil
|
|
||||||
}
|
}
|
||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return m.router, nil
|
return m.family4, nil
|
||||||
}
|
}
|
||||||
if err := m.router.refreshRulesMap(); err != nil {
|
if has(m.family6, id) {
|
||||||
|
return m.family6, nil
|
||||||
|
}
|
||||||
|
if !refresh {
|
||||||
|
return m.family4, nil
|
||||||
|
}
|
||||||
|
if err := m.family4.refreshRulesMap(); err != nil {
|
||||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||||
}
|
}
|
||||||
if err := m.router6.refreshRulesMap(); err != nil {
|
if err := m.family6.refreshRulesMap(); err != nil {
|
||||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||||
}
|
}
|
||||||
if has(m.router6, id) && !has(m.router, id) {
|
if has(m.family6, id) && !has(m.family4, id) {
|
||||||
return m.router6, nil
|
return m.family6, nil
|
||||||
}
|
}
|
||||||
return m.router, nil
|
return m.family4, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
@@ -381,10 +313,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddNatRule(pair)
|
return m.family6.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router.AddNatRule(pair); err != nil {
|
if err := m.family4.AddNatRule(pair); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,7 +328,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
// so the eventual cleanup still works.
|
// so the eventual cleanup still works.
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -412,18 +344,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.router6.RemoveNatRule(pair)
|
return m.family6.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,46 +363,13 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic.
|
|
||||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
|
||||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
|
||||||
//
|
|
||||||
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
|
||||||
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
|
||||||
// Should add passthrough rules to external chains (like the native mode router's
|
|
||||||
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
|
||||||
// The netbird table itself is fine (routing chains already exist there), but
|
|
||||||
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
|
||||||
return fmt.Errorf("create default allow rules: %w", err)
|
|
||||||
}
|
|
||||||
if m.hasIPv6() {
|
|
||||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
|
||||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -484,13 +383,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.router.Reset(); err != nil {
|
if err := m.family4.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.router6.Reset(); err != nil {
|
if err := m.family6.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,14 +429,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -551,13 +450,13 @@ func (m *Manager) Flush() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if err := m.aclManager.Flush(); err != nil {
|
if err := m.family4.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.aclManager6.Flush(); err != nil {
|
if err := m.family6.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush v6 acl: %w", err)
|
return fmt.Errorf("flush v6 family: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,9 +476,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddDNATRule(rule)
|
return m.family6.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.router.AddDNATRule(rule)
|
return m.family4.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
@@ -587,7 +486,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -608,12 +507,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -630,9 +529,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
@@ -644,9 +543,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
@@ -658,9 +557,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
@@ -672,9 +571,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -903,3 +802,14 @@ func getEstablishedExprs(register uint32) []expr.Any {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
|
||||||
|
// prefix destination the destination family decides; otherwise the
|
||||||
|
// (single-family) sources do, since management duplicates rules per
|
||||||
|
// family.
|
||||||
|
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
|
}
|
||||||
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
require.Len(t, rules, 2, "expected 2 rules")
|
require.Len(t, rules, 2, "expected 2 rules")
|
||||||
@@ -147,15 +149,12 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
||||||
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
||||||
|
|
||||||
for _, r := range rule {
|
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
||||||
err = manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
@@ -180,47 +179,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
|||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
// Add accept rule first
|
// Add accept rule first
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add accept rule")
|
require.NoError(t, err, "failed to add accept rule")
|
||||||
|
|
||||||
// Add deny rule second for the same traffic
|
// Add deny rule second for the same traffic
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||||
|
|
||||||
// Find the accept and deny rules and verify deny comes before accept
|
// Single-source rules emit a direct payload+cmp on the source IP
|
||||||
|
// (no set lookup). Match by source-IP + port + verdict instead of
|
||||||
|
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
|
||||||
|
// that this test predates.
|
||||||
|
wantSrc := ip.AsSlice()
|
||||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
hasAcceptHTTPSet := false
|
var hasSrc, hasPort80 bool
|
||||||
hasDenyHTTPSet := false
|
|
||||||
hasPort80 := false
|
|
||||||
var action string
|
var action string
|
||||||
|
|
||||||
for _, e := range rule.Exprs {
|
for _, e := range rule.Exprs {
|
||||||
// Check for set lookup
|
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
|
||||||
if lookup, ok := e.(*expr.Lookup); ok {
|
if bytes.Equal(cmp.Data, wantSrc) {
|
||||||
switch lookup.SetName {
|
hasSrc = true
|
||||||
case "accept-http":
|
|
||||||
hasAcceptHTTPSet = true
|
|
||||||
case "deny-http":
|
|
||||||
hasDenyHTTPSet = true
|
|
||||||
}
|
}
|
||||||
|
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||||
}
|
|
||||||
// Check for port 80
|
|
||||||
if cmp, ok := e.(*expr.Cmp); ok {
|
|
||||||
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
|
||||||
hasPort80 = true
|
hasPort80 = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check for verdict
|
|
||||||
if verdict, ok := e.(*expr.Verdict); ok {
|
if verdict, ok := e.(*expr.Verdict); ok {
|
||||||
switch verdict.Kind {
|
switch verdict.Kind {
|
||||||
case expr.VerdictAccept:
|
case expr.VerdictAccept:
|
||||||
@@ -231,11 +222,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
|
if !hasSrc || !hasPort80 {
|
||||||
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
|
continue
|
||||||
|
}
|
||||||
|
switch action {
|
||||||
|
case "ACCEPT":
|
||||||
|
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
|
||||||
acceptRuleIndex = i
|
acceptRuleIndex = i
|
||||||
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
|
case "DROP":
|
||||||
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
|
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
|
||||||
denyRuleIndex = i
|
denyRuleIndex = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -279,7 +274,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -361,10 +356,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := netip.MustParseAddr("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||||
@@ -437,10 +432,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := netip.MustParseAddr("fd00::2")
|
ip := netip.MustParseAddr("fd00::2")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err, "add v6 peer filtering rule")
|
require.NoError(t, err, "add v6 peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||||
@@ -550,7 +545,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
|||||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
prefixes,
|
prefixes,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
@@ -565,7 +560,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
}
|
}
|
||||||
@@ -591,9 +586,9 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
@@ -606,6 +601,73 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerMultiPortFilter(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil), "failed to reset manager state")
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
|
|
||||||
|
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80, 443}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "failed to add multi-port rule")
|
||||||
|
|
||||||
|
testClient := &nftables.Conn{}
|
||||||
|
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||||
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
|
var lookup *expr.Lookup
|
||||||
|
for _, kernelRule := range rules {
|
||||||
|
if string(kernelRule.UserData) != string(rule.ID()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, e := range kernelRule.Exprs {
|
||||||
|
if l, ok := e.(*expr.Lookup); ok {
|
||||||
|
lookup = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, lookup, "multi-port rule must match ports via a set lookup")
|
||||||
|
|
||||||
|
sets, err := testClient.GetSets(manager.family4.workTable)
|
||||||
|
require.NoError(t, err, "failed to get sets")
|
||||||
|
|
||||||
|
var portSet *nftables.Set
|
||||||
|
for _, s := range sets {
|
||||||
|
if s.Name == lookup.SetName {
|
||||||
|
portSet = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, portSet, "anonymous port set not found in kernel")
|
||||||
|
|
||||||
|
portSet.Table = manager.family4.workTable
|
||||||
|
elements, err := testClient.GetSetElements(portSet)
|
||||||
|
require.NoError(t, err, "failed to get set elements")
|
||||||
|
|
||||||
|
ports := make(map[uint16]bool)
|
||||||
|
for _, e := range elements {
|
||||||
|
require.Len(t, e.Key, 2, "port set element key should be 2 bytes")
|
||||||
|
ports[binary.BigEndian.Uint16(e.Key)] = true
|
||||||
|
}
|
||||||
|
require.True(t, ports[80], "port set should contain port 80")
|
||||||
|
require.True(t, ports[443], "port set should contain port 443")
|
||||||
|
|
||||||
|
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
||||||
|
|
||||||
|
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||||
|
require.NoError(t, err, "failed to get rules after delete")
|
||||||
|
for _, kernelRule := range rules {
|
||||||
|
require.NotEqual(t, string(rule.ID()), string(kernelRule.UserData), "rule should be removed from kernel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
// need fw manager to init both acl mgr and router for all chains to be present
|
// need fw manager to init both acl mgr and family for all chains to be present
|
||||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
rtr := manager.router
|
rtr := manager.family4
|
||||||
err = rtr.AddNatRule(testCase.InputPair)
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "pair should be inserted")
|
require.NoError(t, err, "pair should be inserted")
|
||||||
|
|
||||||
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
testRouter := &router{af: afIPv4}
|
testRouter := &family{af: afIPv4}
|
||||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
|
||||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
testingExpression = append(testingExpression, sourceExp...)
|
testingExpression = append(testingExpression, sourceExp...)
|
||||||
testingExpression = append(testingExpression, destExp...)
|
testingExpression = append(testingExpression, destExp...)
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range rtr.chains {
|
for _, chain := range rtr.chains {
|
||||||
if chain.Name == chainNameManglePrerouting {
|
if chain.Name == chainNameManglePrerouting {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||||
// Compare expressions up to the mark setting expressions
|
// Compare expressions up to the mark setting expressions
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||||
found = 1
|
found = 1
|
||||||
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
rtr := manager.router
|
rtr := manager.family4
|
||||||
|
|
||||||
// First add the NAT rule using the router's method
|
// First add the NAT rule using the family's method
|
||||||
err = rtr.AddNatRule(testCase.InputPair)
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "should add NAT rule")
|
require.NoError(t, err, "should add NAT rule")
|
||||||
|
|
||||||
// Verify the rule was added
|
// Verify the rule was added
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||||
found := false
|
found := false
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules")
|
require.NoError(t, err, "should list rules")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules after removal")
|
require.NoError(t, err, "should list rules after removal")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -200,11 +200,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router")
|
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func(r *router) {
|
defer func(r *family) {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||||
}(r)
|
}(r)
|
||||||
|
|
||||||
@@ -314,16 +313,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddFilterRule failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||||
})
|
})
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
|
||||||
rule, ok := r.rules[ruleKey.ID()]
|
require.True(t, ok, "Rule not found in filters map")
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
rule := stored.nftRule
|
||||||
|
|
||||||
t.Log("Internal rule expressions:")
|
t.Log("Internal rule expressions:")
|
||||||
for i, expr := range rule.Exprs {
|
for i, expr := range rule.Exprs {
|
||||||
@@ -339,7 +338,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
var nftRule *nftables.Rule
|
var nftRule *nftables.Rule
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if string(rule.UserData) == ruleKey.ID() {
|
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
|
||||||
nftRule = rule
|
nftRule = rule
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -367,12 +366,11 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router")
|
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -509,6 +507,41 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
|
||||||
|
// overlapping prefixes before adding them. An interval set rejects
|
||||||
|
// overlapping elements, so without the merge a batch holding a /32 already
|
||||||
|
// covered by a /24, or a duplicate address as DNS resolution can produce,
|
||||||
|
// would fail.
|
||||||
|
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err, "create work table")
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, r.Reset(), "reset family")
|
||||||
|
}()
|
||||||
|
|
||||||
|
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||||
|
set := firewall.NewPrefixSet(initial)
|
||||||
|
|
||||||
|
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
|
||||||
|
require.NoError(t, err, "create ip set")
|
||||||
|
require.NotNil(t, created)
|
||||||
|
|
||||||
|
overlapping := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
|
netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
|
}
|
||||||
|
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
@@ -518,11 +551,10 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to create v6 work table")
|
require.NoError(t, err, "Failed to create v6 work table")
|
||||||
defer deleteWorkTableIPv6()
|
defer deleteWorkTableIPv6()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router")
|
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -748,6 +780,14 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case *expr.Lookup:
|
||||||
|
// Multiple discrete ports compile to an anonymous set lookup
|
||||||
|
// rather than a chain of comparisons. The set's id and name are
|
||||||
|
// assigned dynamically, so matching the lookup is enough here;
|
||||||
|
// the set elements are verified separately.
|
||||||
|
if !port.IsRange && len(port.Values) > 1 {
|
||||||
|
portMatchFound = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if payloadFound && portMatchFound {
|
if payloadFound && portMatchFound {
|
||||||
return true
|
return true
|
||||||
@@ -861,13 +901,12 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
// Add a real rule to the kernel
|
// Add a real rule to the kernel
|
||||||
ruleKey, err := r.AddRouteFiltering(
|
ruleKey, err := r.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
@@ -878,11 +917,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
staleKey := "stale-rule-that-does-not-exist"
|
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
|
||||||
r.rules[staleKey] = &nftables.Rule{
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
@@ -902,6 +941,54 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
|
||||||
|
// rule is actually removed from the kernel on delete. The route chain is
|
||||||
|
// not refreshed by Flush, so the stored rule carries a zero handle;
|
||||||
|
// DeleteFilterRule must pull live handles itself before issuing the
|
||||||
|
// delete or the kernel rule leaks. Regression test for that path.
|
||||||
|
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
ruleKey, err := r.AddFilterRule(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
firewall.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&firewall.Port{Values: []uint16{80}},
|
||||||
|
firewall.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
countKernelRules := func() int {
|
||||||
|
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
require.NoError(t, err)
|
||||||
|
n := 0
|
||||||
|
for _, rule := range list {
|
||||||
|
if string(rule.UserData) == string(ruleKey.ID()) {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
|
||||||
|
|
||||||
|
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||||
|
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
|
||||||
|
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
|
||||||
|
}
|
||||||
|
|
||||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
@@ -911,24 +998,27 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0
|
// Inject a stale entry with Handle=0
|
||||||
staleKey := "stale-route-rule"
|
staleKey := id.RuleID("stale-route-rule")
|
||||||
r.rules[staleKey] = &nftables.Rule{
|
staleRule := &Rule{
|
||||||
Table: r.workTable,
|
nftRule: &nftables.Rule{
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Table: r.workTable,
|
||||||
Handle: 0,
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
UserData: []byte(staleKey),
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
},
|
||||||
|
id: staleKey,
|
||||||
}
|
}
|
||||||
|
r.filters[staleKey] = staleRule
|
||||||
|
|
||||||
// DeleteRouteRule should not return an error for stale handles
|
// DeleteFilterRule should not return an error for stale handles
|
||||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
err = r.DeleteFilterRule(staleRule)
|
||||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
@@ -950,7 +1040,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
rtr := manager.router
|
rtr := manager.family4
|
||||||
|
|
||||||
// First add succeeds
|
// First add succeeds
|
||||||
err = rtr.AddNatRule(pair)
|
err = rtr.AddNatRule(pair)
|
||||||
@@ -960,11 +1050,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Corrupt the handle to simulate stale state
|
// Corrupt the handle to simulate stale state
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
|
||||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
rule.Handle = 0
|
rule.Handle = 0
|
||||||
}
|
}
|
||||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
|
||||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
rule.Handle = 0
|
rule.Handle = 0
|
||||||
}
|
}
|
||||||
@@ -979,7 +1069,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
|
|
||||||
found := 0
|
found := 0
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||||
found++
|
found++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1010,7 +1100,7 @@ func TestCalculateLastIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||||
r := &router{af: afIPv6}
|
r := &family{af: afIPv6}
|
||||||
prefixes := []netip.Prefix{
|
prefixes := []netip.Prefix{
|
||||||
netip.MustParsePrefix("fd00::/64"),
|
netip.MustParsePrefix("fd00::/64"),
|
||||||
netip.MustParsePrefix("2001:db8::1/128"),
|
netip.MustParsePrefix("2001:db8::1/128"),
|
||||||
|
|||||||
500
client/firewall/nftables/routing_linux.go
Normal file
500
client/firewall/nftables/routing_linux.go
Normal file
@@ -0,0 +1,500 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.legacyManagement {
|
||||||
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
|
r.rollbackRules(pair)
|
||||||
|
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
|
if err := r.addNatRule(pair); err != nil {
|
||||||
|
r.rollbackRules(pair)
|
||||||
|
return fmt.Errorf("add nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
r.rollbackRules(pair)
|
||||||
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
r.rollbackRules(pair)
|
||||||
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *family) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []firewall.RuleID{
|
||||||
|
pair.GenKey(firewall.ForwardingFormat),
|
||||||
|
pair.GenKey(firewall.PreroutingFormat),
|
||||||
|
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
|
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||||
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(sourceExp)
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
op := expr.CmpOpEq
|
||||||
|
if pair.Inverse {
|
||||||
|
op = expr.CmpOpNeq
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: op,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
exprs = append(exprs, getCtNewExprs()...)
|
||||||
|
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
|
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||||
|
if pair.Inverse {
|
||||||
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
SourceRegister: true,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
r.dropNetworkMatch(sourceExp)
|
||||||
|
r.dropNetworkMatch(destExp)
|
||||||
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleID),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addPostroutingRules() {
|
||||||
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
|
exprs := []expr.Any{
|
||||||
|
// Match on the first fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
},
|
||||||
|
|
||||||
|
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Second masquerade rule for traffic going out through WireGuard interface
|
||||||
|
exprs2 := []expr.Any{
|
||||||
|
// Match on the second fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Match WireGuard interface
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
func (r *family) addMSSClampingRules() error {
|
||||||
|
overhead := uint16(ipv4TCPHeaderSize)
|
||||||
|
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||||
|
overhead = ipv6TCPHeaderSize
|
||||||
|
}
|
||||||
|
if r.mtu <= overhead {
|
||||||
|
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mss := r.mtu - overhead
|
||||||
|
|
||||||
|
exprsOut := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyL4PROTO,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{unix.IPPROTO_TCP},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 13,
|
||||||
|
Len: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
DestRegister: 1,
|
||||||
|
SourceRegister: 1,
|
||||||
|
Len: 1,
|
||||||
|
Mask: []byte{0x02},
|
||||||
|
Xor: []byte{0x00},
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0x00},
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Exthdr{
|
||||||
|
DestRegister: 1,
|
||||||
|
Type: 2,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
Op: expr.ExthdrOpTcpopt,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpGt,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||||
|
},
|
||||||
|
&expr.Exthdr{
|
||||||
|
SourceRegister: 1,
|
||||||
|
Type: 2,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
Op: expr.ExthdrOpTcpopt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameMangleForward],
|
||||||
|
Exprs: exprsOut,
|
||||||
|
})
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
r.dropNetworkMatch(sourceExp)
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
|
)
|
||||||
|
|
||||||
|
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleID]; exists {
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
r.dropNetworkMatch(sourceExp)
|
||||||
|
r.dropNetworkMatch(destExp)
|
||||||
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleID),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||||
|
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||||
|
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.deleteLegacyRuleEntry(ruleID, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteLegacyRuleEntry removes one legacy forwarding rule and drops its
|
||||||
|
// ipset references. It also clears stale entries that never got a handle.
|
||||||
|
func (r *family) deleteLegacyRuleEntry(ruleID firewall.RuleID, rule *nftables.Rule) error {
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLegacyManagement returns the route manager's legacy management mode
|
||||||
|
func (r *family) GetLegacyManagement() bool {
|
||||||
|
return r.legacyManagement
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||||
|
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||||
|
r.legacyManagement = isLegacy
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||||
|
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for k, rule := range r.rules {
|
||||||
|
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.deleteLegacyRuleEntry(k, rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeNatPreroutingRules() error {
|
||||||
|
table := &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: r.af.tableFamily,
|
||||||
|
}
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
}
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from nat table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Delete rules that have our UserData suffix
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||||
|
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||||
|
|
||||||
|
rule, exists := r.rules[ruleID]
|
||||||
|
if !exists {
|
||||||
|
log.Debugf("prerouting rule %s not found", ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,21 +1,26 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule wraps an installed filter rule (peer or route). Source set
|
||||||
|
// membership is encoded in the rule's expressions; DeleteFilterRule
|
||||||
|
// recovers the set name via findSets so the refcounter can drop the
|
||||||
|
// right reference. mangleRule is set only for peer rules.
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
mangleRule *nftables.Rule
|
mangleRule *nftables.Rule
|
||||||
nftSet *nftables.Set
|
// sources is the canonical source list this rule was created for.
|
||||||
ruleID string
|
sources []netip.Prefix
|
||||||
ip net.IP
|
id manager.RuleID
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *Rule) ID() string {
|
func (r *Rule) ID() manager.RuleID {
|
||||||
return r.ruleID
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pfx(ip net.IP) []netip.Prefix {
|
||||||
|
if ip == nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
if ip.IsUnspecified() {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
a, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||||
|
}
|
||||||
|
a = a.Unmap()
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||||
|
}
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.resetState()
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.Close(stateManager)
|
|
||||||
}
|
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
if m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.AllowNetbird()
|
|
||||||
}
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
Address() wgaddr.Address
|
|
||||||
GetWGDevice() *wgdevice.Device
|
|
||||||
GetDevice() *device.FilteredDevice
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -20,14 +19,18 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -58,7 +61,10 @@ const (
|
|||||||
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||||
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||||
|
|
||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter is a deprecated alias for
|
||||||
|
// NB_FORCE_USERSPACE_FIREWALL: the userspace firewall always routes in
|
||||||
|
// userspace, so forcing one forces the other. Kept for backward
|
||||||
|
// compatibility.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||||
@@ -70,14 +76,20 @@ const (
|
|||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
// errNotSupported is returned by firewall operations that only make sense with
|
||||||
|
// a kernel firewall (kernel NAT/DNAT, eBPF) and are not implemented in
|
||||||
|
// userspace mode, where they should not be called.
|
||||||
|
var errNotSupported = errors.New("not supported with userspace firewall")
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// peerRules is the canonical list-based storage for peer ACL rules.
|
||||||
type RuleSet map[string]PeerRule
|
// Drop and accept rules live in separate slices; drop-before-accept
|
||||||
|
// ordering comes from consulting the deny slice (and its index) before
|
||||||
|
// the accept one.
|
||||||
|
type peerRules []*PeerRule
|
||||||
|
|
||||||
type RouteRules []*RouteRule
|
type routeRules []*RouteRule
|
||||||
|
|
||||||
func (r RouteRules) Sort() {
|
func (r routeRules) Sort() {
|
||||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||||
// Deny rules come first
|
// Deny rules come first
|
||||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
@@ -86,22 +98,74 @@ func (r RouteRules) Sort() {
|
|||||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return strings.Compare(a.id, b.id)
|
return strings.Compare(string(a.id), string(b.id))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerRuleSpec carries the parameters that define a peer filter rule,
|
||||||
|
// threaded together through the build path so the builders take a single
|
||||||
|
// argument instead of a long parameter list.
|
||||||
|
type peerRuleSpec struct {
|
||||||
|
mgmtID []byte
|
||||||
|
sources []netip.Prefix
|
||||||
|
ipLayer gopacket.LayerType
|
||||||
|
proto firewall.Protocol
|
||||||
|
sPort *firewall.Port
|
||||||
|
dPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iface is the network interface the userspace firewall attaches to: the
|
||||||
|
// methods of the WireGuard device it actually uses.
|
||||||
|
type Iface interface {
|
||||||
|
Name() string
|
||||||
|
Address() wgaddr.Address
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterfaceAllower opens the NetBird interface in the host firewall so it
|
||||||
|
// doesn't drop traffic the userspace firewall handles, without taking over
|
||||||
|
// packet filtering. Implementations (nftables, iptables, firewalld, the windows
|
||||||
|
// netsh rule) are selected per platform and injected into Create; Apply runs at
|
||||||
|
// creation and Close on teardown.
|
||||||
|
type InterfaceAllower interface {
|
||||||
|
Apply() error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds the dependencies and options for the userspace firewall.
|
||||||
|
type Config struct {
|
||||||
|
// IFace is the overlay interface the filter attaches to.
|
||||||
|
IFace Iface
|
||||||
|
// InterfaceAllower opens the NetBird interface in foreign kernel filter
|
||||||
|
// chains so the kernel doesn't drop traffic the userspace firewall handles.
|
||||||
|
// Nil in netstack mode, on non-Linux platforms without a backend, or when
|
||||||
|
// neither nftables nor iptables is available. firewalld trust is applied by
|
||||||
|
// the manager regardless, since firewalld owns its own chains and we cannot
|
||||||
|
// insert into them.
|
||||||
|
InterfaceAllower InterfaceAllower
|
||||||
|
// DisableServerRoutes indicates whether server routes are disabled.
|
||||||
|
DisableServerRoutes bool
|
||||||
|
FlowLogger nftypes.FlowLogger
|
||||||
|
MTU uint16
|
||||||
|
}
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules map[netip.Addr]RuleSet
|
decoders sync.Pool
|
||||||
incomingDenyRules map[netip.Addr]RuleSet
|
wgIface Iface
|
||||||
incomingRules map[netip.Addr]RuleSet
|
ifaceAllower InterfaceAllower
|
||||||
routeRules RouteRules
|
mutex sync.RWMutex
|
||||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
|
||||||
decoders sync.Pool
|
|
||||||
wgIface common.IFaceMapper
|
|
||||||
nativeFirewall firewall.Manager
|
|
||||||
|
|
||||||
mutex sync.RWMutex
|
incomingDenyRules peerRules
|
||||||
|
incomingAcceptRules peerRules
|
||||||
|
incomingDenyIndex peerRuleIndex
|
||||||
|
incomingAcceptIndex peerRuleIndex
|
||||||
|
peerRulesMap map[nbid.RuleID]*PeerRule
|
||||||
|
|
||||||
|
routeRules routeRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
|
|
||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
@@ -183,24 +247,6 @@ func (d *decoder) decodePacket(data []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
|
||||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
|
||||||
if nativeFirewall == nil {
|
|
||||||
return nil, errors.New("native firewall is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return mgr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCreateEnv() (bool, bool, bool) {
|
func parseCreateEnv() (bool, bool, bool) {
|
||||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||||
var err error
|
var err error
|
||||||
@@ -231,7 +277,7 @@ func parseCreateEnv() (bool, bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
func Create(cfg Config) (_ *Manager, err error) {
|
||||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -254,62 +300,131 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
wgIface: cfg.IFace,
|
||||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
ifaceAllower: cfg.InterfaceAllower,
|
||||||
incomingDenyRules: make(map[netip.Addr]RuleSet),
|
|
||||||
incomingRules: make(map[netip.Addr]RuleSet),
|
|
||||||
wgIface: iface,
|
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: disableServerRoutes,
|
disableServerRoutes: cfg.DisableServerRoutes,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
flowLogger: flowLogger,
|
flowLogger: cfg.FlowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
|
||||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
mtu: mtu,
|
mtu: cfg.MTU,
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
|
// Release the allower (and its monitor) if setup fails after it was wired in.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
m.closeAllowerOnError()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if !disableMSSClamping {
|
if !disableMSSClamping {
|
||||||
m.mssClampEnabled = true
|
m.enableMSSClamping(cfg.MTU)
|
||||||
if mtu > ipv4TCPHeaderMinSize {
|
|
||||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
if mtu > ipv6TCPHeaderMinSize {
|
|
||||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(cfg.IFace); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
if disableConntrack {
|
m.setupConntrack(disableConntrack)
|
||||||
log.Info("conntrack is disabled")
|
|
||||||
} else {
|
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
|
||||||
}
|
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
if err := m.initForwarder(); err != nil {
|
if err := m.initForwarder(); err != nil {
|
||||||
log.Errorf("failed to initialize forwarder: %v", err)
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := cfg.IFace.SetFilter(m); err != nil {
|
||||||
return nil, fmt.Errorf("set filter: %w", err)
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.openHostFirewall(cfg.IFace.Name())
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// closeAllowerOnError releases the allower (and its monitor) when Create fails
|
||||||
|
// after the allower was wired in.
|
||||||
|
func (m *Manager) closeAllowerOnError() {
|
||||||
|
if m.ifaceAllower == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := m.ifaceAllower.Close(); err != nil {
|
||||||
|
log.Warnf("close interface allower after failed firewall setup: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// enableMSSClamping enables MSS clamping and computes the per-family clamp values.
|
||||||
|
func (m *Manager) enableMSSClamping(mtu uint16) {
|
||||||
|
m.mssClampEnabled = true
|
||||||
|
if mtu > ipv4TCPHeaderMinSize {
|
||||||
|
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||||
|
}
|
||||||
|
if mtu > ipv6TCPHeaderMinSize {
|
||||||
|
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupConntrack initializes the stateful trackers unless conntrack is disabled.
|
||||||
|
func (m *Manager) setupConntrack(disabled bool) {
|
||||||
|
if disabled {
|
||||||
|
log.Info("conntrack is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// openHostFirewall opens the NetBird interface in the kernel firewall so it
|
||||||
|
// doesn't drop traffic the userspace firewall handles. Best-effort: failures
|
||||||
|
// here shouldn't prevent the firewall from coming up.
|
||||||
|
func (m *Manager) openHostFirewall(ifaceName string) {
|
||||||
|
if m.ifaceAllower != nil {
|
||||||
|
if err := m.ifaceAllower.Apply(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// firewalld owns its own chains we can't insert into, so trust the interface
|
||||||
|
// there in addition to the allower. Netstack has no kernel interface.
|
||||||
|
if !m.netstack {
|
||||||
|
if err := firewalld.TrustInterface(ifaceName); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up the firewall manager: removes rules, closes trackers, and
|
||||||
|
// closes the interface allower.
|
||||||
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.resetState()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if m.ifaceAllower != nil {
|
||||||
|
if err := m.ifaceAllower.Close(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("close interface allower: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !m.netstack {
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("untrust interface in firewalld: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
||||||
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
||||||
// failure leaves v4 protection in place (and vice versa) so the returned
|
// failure leaves v4 protection in place (and vice versa) so the returned
|
||||||
// slice always contains whatever was successfully installed, even on error.
|
// slice always contains whatever was successfully installed, even on error.
|
||||||
// Callers must persist the slice so DisableRouting can clean partial state.
|
// Callers must persist the slice so DisableRouting can clean partial state.
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
|
func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
|
||||||
wgPrefix := iface.Address().Network
|
wgPrefix := iface.Address().Network
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
@@ -320,7 +435,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
v4Rule, err := m.addRouteFiltering(
|
v4Rule, err := m.addRouteRule(
|
||||||
nil,
|
nil,
|
||||||
sources,
|
sources,
|
||||||
firewall.Network{Prefix: wgPrefix},
|
firewall.Network{Prefix: wgPrefix},
|
||||||
@@ -336,7 +451,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
|||||||
|
|
||||||
if v6Net.IsValid() {
|
if v6Net.IsValid() {
|
||||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||||
v6Rule, err := m.addRouteFiltering(
|
v6Rule, err := m.addRouteRule(
|
||||||
nil,
|
nil,
|
||||||
sources,
|
sources,
|
||||||
firewall.Network{Prefix: v6Net},
|
firewall.Network{Prefix: v6Net},
|
||||||
@@ -357,20 +472,14 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting() error {
|
func (m *Manager) determineRouting() error {
|
||||||
var disableUspRouting, forceUserspaceRouter bool
|
var disableUspRouting bool
|
||||||
var err error
|
|
||||||
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
|
var err error
|
||||||
disableUspRouting, err = strconv.ParseBool(val)
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
|
||||||
forceUserspaceRouter, err = strconv.ParseBool(val)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
@@ -385,26 +494,11 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
case forceUserspaceRouter:
|
|
||||||
m.routingEnabled.Store(true)
|
|
||||||
m.nativeRouter.Store(false)
|
|
||||||
|
|
||||||
log.Info("userspace routing is forced")
|
|
||||||
|
|
||||||
case !m.netstack && m.nativeFirewall != nil:
|
|
||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
|
||||||
// netstack mode won't support native routing as there is no interface
|
|
||||||
|
|
||||||
m.routingEnabled.Store(true)
|
|
||||||
m.nativeRouter.Store(true)
|
|
||||||
|
|
||||||
log.Info("native routing is enabled")
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled.Store(true)
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter.Store(false)
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
@@ -470,96 +564,118 @@ func (m *Manager) IsStateful() bool {
|
|||||||
return m.stateful
|
return m.stateful
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
// userspace routed packets are always SNATed to the inbound direction
|
// userspace routed packets are always SNATed to the inbound direction
|
||||||
// TODO: implement outbound SNAT
|
// TODO: implement outbound SNAT
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// addPeerRule installs an input-chain rule that matches packets
|
||||||
//
|
// by source only. Called from AddFilterRule when the caller doesn't
|
||||||
// If comment argument is empty firewall manager should set
|
// specify a destination. Sources are expected to share one address
|
||||||
// rule ID as comment for the rule
|
// family; the family selects the ipLayer so the ICMP variant matches
|
||||||
func (m *Manager) AddPeerFiltering(
|
// what the decoder produces.
|
||||||
|
func (m *Manager) addPeerRule(
|
||||||
id []byte,
|
id []byte,
|
||||||
ip net.IP,
|
sources []netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
_ string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
// TODO: fix in upper layers
|
|
||||||
i, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
i = i.Unmap()
|
|
||||||
r := PeerRule{
|
|
||||||
id: uuid.New().String(),
|
|
||||||
mgmtId: id,
|
|
||||||
ip: i,
|
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
|
||||||
matchByIP: true,
|
|
||||||
drop: action == firewall.ActionDrop,
|
|
||||||
}
|
|
||||||
if i.Is4() {
|
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
|
||||||
r.matchByIP = false
|
|
||||||
}
|
|
||||||
|
|
||||||
r.sPort = sPort
|
|
||||||
r.dPort = dPort
|
|
||||||
|
|
||||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
var targetMap map[netip.Addr]RuleSet
|
|
||||||
if r.drop {
|
|
||||||
targetMap = m.incomingDenyRules
|
|
||||||
} else {
|
|
||||||
targetMap = m.incomingRules
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := targetMap[r.ip]; !ok {
|
|
||||||
targetMap[r.ip] = make(RuleSet)
|
|
||||||
}
|
|
||||||
targetMap[r.ip][r.id] = r
|
|
||||||
m.mutex.Unlock()
|
|
||||||
return []firewall.Rule{&r}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
|
||||||
id []byte,
|
|
||||||
sources []netip.Prefix,
|
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort, dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
// Sources are a single family; normalize v4-mapped prefixes to plain
|
||||||
|
// v4 and pick the matching IP layer. A /0 source matches any address
|
||||||
|
// of its own family only, mirroring the kernel backends.
|
||||||
|
normalized := make([]netip.Prefix, len(sources))
|
||||||
|
ipLayer := layers.LayerTypeIPv4
|
||||||
|
for i, p := range sources {
|
||||||
|
normalized[i] = firewall.UnmapPrefix(p)
|
||||||
|
if normalized[i].Addr().Is6() {
|
||||||
|
ipLayer = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
spec := peerRuleSpec{
|
||||||
|
mgmtID: id,
|
||||||
|
sources: normalized,
|
||||||
|
ipLayer: ipLayer,
|
||||||
|
proto: proto,
|
||||||
|
sPort: sPort,
|
||||||
|
dPort: dPort,
|
||||||
|
action: action,
|
||||||
|
}
|
||||||
|
return m.addOnePeerRule(spec), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) addRouteFiltering(
|
// addOnePeerRule builds and registers a single-family peer rule, or
|
||||||
|
// returns the existing rule when one with the same content key is
|
||||||
|
// already installed. The caller must hold m.mutex. The content key is
|
||||||
|
// the shared GenerateRuleID with an empty destination, so peer rules
|
||||||
|
// dedup the same way route rules and the kernel backends do; it is
|
||||||
|
// order-independent, so callers passing the same sources in any order
|
||||||
|
// dedup to one rule.
|
||||||
|
//
|
||||||
|
// There is no refcount: a content key is installed once and deleted on
|
||||||
|
// the first DeleteFilterRule for that key. The caller must therefore
|
||||||
|
// key its own tracking by the returned rule id so add and delete stay
|
||||||
|
// balanced per content key; the acl manager does this via
|
||||||
|
// peerRulesPairs.
|
||||||
|
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
|
||||||
|
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
|
||||||
|
if existing, ok := m.peerRulesMap[ruleID]; ok {
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := m.buildPeerRule(ruleID, spec)
|
||||||
|
m.registerPeerRule(rule)
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
|
||||||
|
r := &PeerRule{
|
||||||
|
id: ruleID,
|
||||||
|
mgmtId: spec.mgmtID,
|
||||||
|
sources: spec.sources,
|
||||||
|
action: spec.action,
|
||||||
|
srcPort: spec.sPort,
|
||||||
|
dstPort: spec.dPort,
|
||||||
|
}
|
||||||
|
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
|
||||||
|
for _, p := range spec.sources {
|
||||||
|
if p.Bits() == p.Addr().BitLen() {
|
||||||
|
r.sourceAddrs[p.Addr()] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerPeerRule records a freshly built peer rule in the matching
|
||||||
|
// slice, index, and dedup map. The caller must hold m.mutex.
|
||||||
|
func (m *Manager) registerPeerRule(r *PeerRule) {
|
||||||
|
if r.action == firewall.ActionDrop {
|
||||||
|
m.incomingDenyRules = append(m.incomingDenyRules, r)
|
||||||
|
m.incomingDenyIndex.add(r)
|
||||||
|
} else {
|
||||||
|
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
|
||||||
|
m.incomingAcceptIndex.add(r)
|
||||||
|
}
|
||||||
|
m.peerRulesMap[r.id] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFilterRule is the unified entry point for both peer (input chain)
|
||||||
|
// and route (forward chain) filtering rules. The destination
|
||||||
|
// distinguishes the two semantics: a zero Network installs an
|
||||||
|
// input-side rule that matches by source only; a set Network installs
|
||||||
|
// a forward-side rule that also matches the destination.
|
||||||
|
func (m *Manager) AddFilterRule(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination firewall.Network,
|
destination firewall.Network,
|
||||||
@@ -567,19 +683,49 @@ func (m *Manager) addRouteFiltering(
|
|||||||
sPort, dPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if len(sources) == 0 {
|
||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return nil, firewall.ErrNoSources
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
if destination.IsZero() {
|
||||||
|
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
|
||||||
|
// is used to route to the correct internal path.
|
||||||
|
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if r, ok := rule.(*PeerRule); ok {
|
||||||
|
return m.deletePeerRuleLocked(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anything else is a route rule (matched on the forward path).
|
||||||
|
return m.deleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addRouteRule(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
|
||||||
return existingRule, nil
|
return existingRule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
id: ruleID,
|
||||||
id: string(ruleKey),
|
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -594,78 +740,58 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
m.routeRulesMap[ruleKey] = &rule
|
m.routeRulesMap[ruleID] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
return m.deleteRouteRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
ruleID := rule.ID()
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
}
|
}
|
||||||
|
m.routeRules = trimmed
|
||||||
ruleKey := nbid.RuleID(rule.ID())
|
delete(m.routeRulesMap, ruleID)
|
||||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
|
||||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
|
||||||
return r.id == string(ruleKey)
|
|
||||||
})
|
|
||||||
if idx < 0 {
|
|
||||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
|
||||||
delete(m.routeRulesMap, ruleKey)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// deletePeerRuleLocked removes a peer rule from the matching slice,
|
||||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
// index, and dedup map. The caller must hold m.mutex.
|
||||||
m.mutex.Lock()
|
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
|
||||||
defer m.mutex.Unlock()
|
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
|
||||||
|
if r.action == firewall.ActionDrop {
|
||||||
|
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
|
||||||
|
}
|
||||||
|
|
||||||
r, ok := rule.(*PeerRule)
|
trimmed, stored, ok := removeRuleByID(*target, r.id)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sourceMap map[netip.Addr]RuleSet
|
|
||||||
if r.drop {
|
|
||||||
sourceMap = m.incomingDenyRules
|
|
||||||
} else {
|
|
||||||
sourceMap = m.incomingRules
|
|
||||||
}
|
|
||||||
|
|
||||||
if ruleset, ok := sourceMap[r.ip]; ok {
|
|
||||||
if _, exists := ruleset[r.id]; !exists {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(ruleset, r.id)
|
|
||||||
if len(ruleset) == 0 {
|
|
||||||
delete(sourceMap, r.ip)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
|
*target = trimmed
|
||||||
|
index.remove(stored)
|
||||||
|
delete(m.peerRulesMap, r.id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
// removeRuleByID removes the first rule whose id matches ruleID from
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
// rules, preserving order. It returns the trimmed slice, the removed
|
||||||
if m.nativeFirewall == nil {
|
// rule, and whether a match was found.
|
||||||
return nil
|
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
|
||||||
|
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
|
||||||
|
var removed T
|
||||||
|
if idx < 0 {
|
||||||
|
return rules, removed, false
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
removed = rules[idx]
|
||||||
|
return slices.Delete(rules, idx, idx+1), removed, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement is a no-op for the userspace firewall: it only matters
|
||||||
|
// when an old management server can't send route firewall rules, which the
|
||||||
|
// userspace router doesn't rely on.
|
||||||
|
func (m *Manager) SetLegacyManagement(bool) error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
@@ -674,11 +800,14 @@ func (m *Manager) Flush() error { return nil }
|
|||||||
// resetState clears all firewall rules and closes connection trackers.
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
// Must be called with m.mutex held.
|
// Must be called with m.mutex held.
|
||||||
func (m *Manager) resetState() {
|
func (m *Manager) resetState() {
|
||||||
clear(m.outgoingRules)
|
m.incomingDenyRules = m.incomingDenyRules[:0]
|
||||||
clear(m.incomingDenyRules)
|
m.incomingAcceptRules = m.incomingAcceptRules[:0]
|
||||||
clear(m.incomingRules)
|
m.incomingDenyIndex.reset()
|
||||||
|
m.incomingAcceptIndex.reset()
|
||||||
|
clear(m.peerRulesMap)
|
||||||
clear(m.routeRulesMap)
|
clear(m.routeRulesMap)
|
||||||
m.routeRules = m.routeRules[:0]
|
m.routeRules = m.routeRules[:0]
|
||||||
|
m.blockRules = nil
|
||||||
m.udpHookOut.Store(nil)
|
m.udpHookOut.Store(nil)
|
||||||
m.tcpHookOut.Store(nil)
|
m.tcpHookOut.Store(nil)
|
||||||
|
|
||||||
@@ -708,21 +837,15 @@ func (m *Manager) resetState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack is not supported by the userspace firewall: eBPF isn't
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
// used in userspace mode, so this should never be called.
|
||||||
if m.nativeFirewall == nil {
|
func (m *Manager) SetupEBPFProxyNoTrack(uint16, uint16) error {
|
||||||
return nil
|
return errNotSupported
|
||||||
}
|
|
||||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
|
||||||
return m.nativeFirewall.UpdateSet(set, prefixes)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -820,11 +943,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
|||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
return src, dst
|
return src.Unmap(), dst.Unmap()
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
return src, dst
|
return src.Unmap(), dst.Unmap()
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, netip.Addr{}
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
@@ -1404,20 +1527,12 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
|
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
|
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
|
|
||||||
return mgmtId, filter
|
|
||||||
}
|
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
|
|
||||||
return mgmtId, filter
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1438,39 +1553,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
|
||||||
return rule.mgmtId, rule.drop, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch payloadLayer {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
|
||||||
return rule.mgmtId, rule.drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
|
||||||
return rule.mgmtId, rule.drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
|
||||||
return rule.mgmtId, rule.drop, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
@@ -1547,10 +1629,13 @@ func (m *Manager) EnableRouting() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rules, err := m.blockInvalidRouted(m.wgIface)
|
rules, err := m.blockInvalidRouted(m.wgIface)
|
||||||
// Persist whatever was installed even on partial failure, so DisableRouting
|
|
||||||
// can clean it up later.
|
|
||||||
m.blockRules = rules
|
m.blockRules = rules
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Roll back so forwarding can't stay active without the full set of
|
||||||
|
// block rules.
|
||||||
|
if derr := m.disableRouting(); derr != nil {
|
||||||
|
log.Warnf("roll back routing after block rule failure: %v", derr)
|
||||||
|
}
|
||||||
return fmt.Errorf("block invalid routed: %w", err)
|
return fmt.Errorf("block invalid routed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1561,6 +1646,10 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.disableRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) disableRouting() error {
|
||||||
fwder := m.forwarder.Load()
|
fwder := m.forwarder.Load()
|
||||||
if fwder == nil {
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
ip,
|
pfx(ip), fw.Network{},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept,
|
fw.ActionAccept)
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
net.ParseIP("0.0.0.0"),
|
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
fw.ActionDrop,
|
fw.ActionDrop)
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -168,9 +164,12 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -208,9 +207,12 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
|
|
||||||
for _, count := range connCounts {
|
for _, count := range connCounts {
|
||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -251,9 +253,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -409,9 +414,12 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -536,9 +544,12 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -546,7 +557,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -619,9 +630,12 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -629,7 +643,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -730,16 +744,19 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -810,15 +827,18 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -931,7 +951,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
dst := fw.Network{Prefix: r.dest}
|
dst := fw.Network{Prefix: r.dest}
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1014,9 +1034,11 @@ func BenchmarkMSSClamping(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -1079,9 +1101,11 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -1134,9 +1158,11 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -496,40 +496,32 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.ruleAction == fw.ActionDrop {
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
// add general accept rule for the same IP to test drop rule precedence
|
// add general accept rule for the same IP to test drop rule precedence
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||||
fw.ProtocolALL,
|
fw.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
fw.ActionAccept,
|
fw.ActionAccept)
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotNil(t, rules)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range rules {
|
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction,
|
tc.ruleAction)
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotNil(t, rules)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range rules {
|
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
@@ -557,7 +549,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -652,14 +644,24 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
shouldBeBlocked: false,
|
shouldBeBlocked: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
name: "IPv6: v4 wildcard ICMP rule does not match ICMPv6",
|
||||||
srcIP: "fd00::1",
|
srcIP: "fd00::1",
|
||||||
dstIP: "fd00::100",
|
dstIP: "fd00::100",
|
||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
ruleIP: "0.0.0.0",
|
ruleIP: "0.0.0.0",
|
||||||
ruleProto: fw.ProtocolICMP,
|
ruleProto: fw.ProtocolICMP,
|
||||||
ruleAction: fw.ActionAccept,
|
ruleAction: fw.ActionAccept,
|
||||||
shouldBeBlocked: false,
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4: v6 wildcard ICMP rule does not match ICMPv4",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "::",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -672,22 +674,18 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.ruleAction == fw.ActionDrop {
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range rules {
|
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotNil(t, rules)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range rules {
|
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||||
require.NoError(t, manager.DeletePeerRule(rule))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
@@ -800,7 +798,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
@@ -1405,7 +1403,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.rule.action == fw.ActionDrop {
|
if tc.rule.action == fw.ActionDrop {
|
||||||
// add general accept rule to test drop rule
|
// add general accept rule to test drop rule
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
@@ -1415,13 +1413,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
fw.ActionAccept,
|
fw.ActionAccept,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rule)
|
require.NotEmpty(t, rule)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
@@ -1431,10 +1429,10 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
tc.rule.action,
|
tc.rule.action,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rule)
|
require.NotEmpty(t, rule)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
@@ -1602,9 +1600,9 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
var rules []fw.Rule
|
var addedRules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
@@ -1615,12 +1613,12 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rule)
|
require.NotNil(t, rule)
|
||||||
rules = append(rules, rule)
|
addedRules = append(addedRules, rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range rules {
|
for _, rule := range addedRules {
|
||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1646,7 +1644,7 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1655,7 +1653,7 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
// Add rule that uses the set (initially empty)
|
// Add rule that uses the set (initially empty)
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -1689,7 +1687,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
|||||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||||
|
|
||||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||||
_, err := manager.AddRouteFiltering(
|
_, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
fw.Network{Prefix: v6Dst},
|
fw.Network{Prefix: v6Dst},
|
||||||
@@ -1700,7 +1698,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
|||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
// Add rule first time
|
// Add rule first time
|
||||||
rule1, err := manager.AddRouteFiltering(
|
rule1, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
|||||||
require.NotNil(t, rule1)
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
// Add the same rule again
|
// Add the same rule again
|
||||||
rule2, err := manager.AddRouteFiltering(
|
rule2, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
|||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
// Add first rule
|
// Add first rule
|
||||||
rule1, err := manager.AddRouteFiltering(
|
rule1, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add different rule (different destination)
|
// Add different rule (different destination)
|
||||||
rule2, err := manager.AddRouteFiltering(
|
rule2, err := manager.AddFilterRule(
|
||||||
[]byte("policy-2"),
|
[]byte("policy-2"),
|
||||||
sources,
|
sources,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
rule1, err := manager.AddRouteFiltering(
|
rule1, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
require.True(t, pass, "Traffic should pass with rule in place")
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
// Re-add same rule (simulates network map update)
|
// Re-add same rule (simulates network map update)
|
||||||
rule2, err := manager.AddRouteFiltering(
|
rule2, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
// would remove the only matching rule and cause a traffic gap.
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
if rule1.ID() != rule2.ID() {
|
if rule1.ID() != rule2.ID() {
|
||||||
err = manager.DeleteRouteRule(rule1)
|
err = manager.DeleteFilterRule(rule1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,6 +156,59 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
"Traffic should still pass after rule update - no gap should occur")
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
|
||||||
|
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
|
||||||
|
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
|
||||||
|
// denied. The v4-only soft-skip path is covered by
|
||||||
|
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
|
||||||
|
func TestBlockInvalidRoutedDualStack(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
IPv6: wgNet6.Addr(),
|
||||||
|
IPv6Net: wgNet6,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
rules, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
|
||||||
|
|
||||||
|
// v6 routed traffic to the WG prefix must be denied.
|
||||||
|
srcIP := netip.MustParseAddr("2001:db8::1")
|
||||||
|
dstIP := netip.MustParseAddr("fd00:1234::50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
// returns the same rule without duplicating.
|
// returns the same rule without duplicating.
|
||||||
@@ -182,7 +235,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -245,7 +298,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -274,7 +327,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
|||||||
|
|
||||||
// Simulate 5 network map updates with the same route rule
|
// Simulate 5 network map updates with the same route rule
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -304,7 +357,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
// Add same rule twice
|
// Add same rule twice
|
||||||
rule1, err := manager.AddRouteFiltering(
|
rule1, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -315,7 +368,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rule2, err := manager.AddRouteFiltering(
|
rule2, err := manager.AddFilterRule(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -329,7 +382,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
// Delete using first reference
|
// Delete using first reference
|
||||||
err = manager.DeleteRouteRule(rule1)
|
err = manager.DeleteFilterRule(rule1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify traffic no longer passes
|
// Verify traffic no longer passes
|
||||||
@@ -364,7 +417,7 @@ func setupTestManager(t *testing.T) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.EnableRouting())
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
|||||||
@@ -78,18 +78,19 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||||
|
|
||||||
if m == nil {
|
if m == nil {
|
||||||
t.Error("Manager is nil")
|
t.Error("Manager is nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
func TestManagerAddFilterRule(t *testing.T) {
|
||||||
isSetFilterCalled := false
|
isSetFilterCalled := false
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error {
|
SetFilterFunc: func(device.PacketFilter) error {
|
||||||
@@ -98,18 +99,19 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -131,74 +133,47 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||||
|
|
||||||
ip := netip.MustParseAddr("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check rules exist in appropriate maps
|
peerRule, ok := rule2.(*PeerRule)
|
||||||
for _, r := range rule2 {
|
require.True(t, ok, "rule should be a peer rule")
|
||||||
peerRule, ok := r.(*PeerRule)
|
|
||||||
if !ok {
|
inMap := func() bool {
|
||||||
t.Errorf("rule should be a PeerRule")
|
if peerRule.action == fw.ActionDrop {
|
||||||
continue
|
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
|
||||||
}
|
|
||||||
// Check if rule exists in deny or allow maps based on action
|
|
||||||
var found bool
|
|
||||||
if peerRule.drop {
|
|
||||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
|
||||||
} else {
|
|
||||||
_, found = m.incomingRules[ip][r.ID()]
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Errorf("rule2 is not in the expected rules map")
|
|
||||||
}
|
}
|
||||||
|
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
require.True(t, inMap(), "rule2 should be in the expected rules list")
|
||||||
err = m.DeletePeerRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check rules are removed from appropriate maps
|
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
|
||||||
for _, r := range rule2 {
|
|
||||||
peerRule, ok := r.(*PeerRule)
|
require.False(t, inMap(), "rule2 should be removed from the rules list")
|
||||||
if !ok {
|
|
||||||
t.Errorf("rule should be a PeerRule")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Check if rule is removed from deny or allow maps based on action
|
|
||||||
var found bool
|
|
||||||
if peerRule.drop {
|
|
||||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
|
||||||
} else {
|
|
||||||
_, found = m.incomingRules[ip][r.ID()]
|
|
||||||
}
|
|
||||||
if found {
|
|
||||||
t.Errorf("rule2 should be removed from the rules map")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetUDPPacketHook(t *testing.T) {
|
func TestSetUDPPacketHook(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -220,9 +195,11 @@ func TestSetUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetTCPPacketHook(t *testing.T) {
|
func TestSetTCPPacketHook(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -250,7 +227,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -260,36 +237,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
|||||||
addr := netip.MustParseAddr("192.168.1.1")
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
// Add multiple deny rules for different ports
|
// Add multiple deny rules for different ports
|
||||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
// Delete the first deny rule
|
// Delete the first deny rule
|
||||||
err = m.DeletePeerRule(rule1[0])
|
err = m.DeleteFilterRule(rule1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount = len(m.incomingDenyRules[addr])
|
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
// Delete the second deny rule
|
// Delete the second deny rule
|
||||||
err = m.DeletePeerRule(rule2[0])
|
err = m.DeleteFilterRule(rule2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
_, exists := m.incomingDenyRules[addr]
|
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
require.False(t, exists, "Deny rules should be cleaned up when empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
@@ -299,7 +274,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -311,27 +286,21 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
|||||||
// Simulate 10 network map updates: add rule, delete old, add new
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
// Add a deny rule
|
// Add a deny rule
|
||||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add an allow rule
|
// Add an allow rule
|
||||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Delete them (simulating ACL manager cleanup)
|
// Delete them (simulating ACL manager cleanup)
|
||||||
for _, r := range rules {
|
require.NoError(t, m.DeleteFilterRule(rules))
|
||||||
require.NoError(t, m.DeletePeerRule(r))
|
require.NoError(t, m.DeleteFilterRule(allowRules))
|
||||||
}
|
|
||||||
for _, r := range allowRules {
|
|
||||||
require.NoError(t, m.DeletePeerRule(r))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||||
allowCount := len(m.incomingRules[addr])
|
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
@@ -345,7 +314,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -354,41 +323,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
// Add allow rule for port 80
|
// Add allow rule for port 80
|
||||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add deny rule for port 22
|
// Add deny rule for port 22
|
||||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
addr := netip.MustParseAddr("192.168.1.1")
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
allowCount := len(m.incomingRules[addr])
|
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||||
denyCount := len(m.incomingDenyRules[addr])
|
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
// Delete allow rule should not affect deny rule
|
// Delete allow rule should not affect deny rule
|
||||||
err = m.DeletePeerRule(allowRule[0])
|
err = m.DeleteFilterRule(allowRule)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
// Delete deny rule
|
// Delete deny rule
|
||||||
err = m.DeletePeerRule(denyRule[0])
|
err = m.DeleteFilterRule(denyRule)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
_, denyExists := m.incomingDenyRules[addr]
|
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||||
_, allowExists := m.incomingRules[addr]
|
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.False(t, denyExists, "Deny rules should be empty")
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
@@ -400,7 +367,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -411,7 +378,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -423,7 +390,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||||
t.Errorf("rules are not empty")
|
t.Errorf("rules are not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -439,7 +406,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -449,7 +416,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -502,7 +469,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
manager, err := Create(Config{IFace: iface, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -519,9 +486,11 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
@@ -606,7 +575,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -621,7 +590,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -631,9 +600,11 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, nbiface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
@@ -845,7 +816,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -858,7 +829,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -931,7 +902,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -939,7 +910,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
|
|
||||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -1051,7 +1022,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: 1280})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1243,7 +1214,7 @@ func TestShouldForward(t *testing.T) {
|
|||||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1358,7 +1329,7 @@ func TestShouldForward(t *testing.T) {
|
|||||||
|
|
||||||
// Re-create manager to pick up the new address with IPv6
|
// Re-create manager to pick up the new address with IPv6
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
manager, err = Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
v6Cases := []struct {
|
v6Cases := []struct {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -20,9 +21,9 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,6 +34,12 @@ const (
|
|||||||
iosMaxInFlight = 256
|
iosMaxInFlight = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IFace provides the WireGuard device and overlay addresses the forwarder needs.
|
||||||
|
type IFace interface {
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
Address() wgaddr.Address
|
||||||
|
}
|
||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
@@ -51,7 +58,7 @@ type Forwarder struct {
|
|||||||
pingSemaphore chan struct{}
|
pingSemaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
func New(iface IFace, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
ipv4.NewProtocol,
|
ipv4.NewProtocol,
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type action string
|
type action string
|
||||||
@@ -20,35 +19,20 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
// WindowsInterfaceAllower opens the NetBird interface in the Windows firewall
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
// via netsh advfirewall rules. It implements InterfaceAllower for the userspace
|
||||||
m.mutex.Lock()
|
// firewall on Windows.
|
||||||
defer m.mutex.Unlock()
|
type WindowsInterfaceAllower struct {
|
||||||
|
iface Iface
|
||||||
m.resetState()
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
if isFirewallRuleActive(firewallRuleName) {
|
|
||||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// NewWindowsInterfaceAllower builds the Windows netsh-based interface allower.
|
||||||
func (m *Manager) AllowNetbird() error {
|
func NewWindowsInterfaceAllower(iface Iface) *WindowsInterfaceAllower {
|
||||||
|
return &WindowsInterfaceAllower{iface: iface}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply adds inbound-allow netsh rules for the interface's addresses.
|
||||||
|
func (a *WindowsInterfaceAllower) Apply() error {
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -60,13 +44,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
"enable=yes",
|
"enable=yes",
|
||||||
"action=allow",
|
"action=allow",
|
||||||
"profile=any",
|
"profile=any",
|
||||||
"localip="+m.wgIface.Address().IP.String(),
|
"localip="+a.iface.Address().IP.String(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
if v6 := a.iface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||||
addRule,
|
addRule,
|
||||||
"dir=in",
|
"dir=in",
|
||||||
@@ -82,8 +66,27 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
// Close removes the netsh rules added by Apply.
|
||||||
|
func (a *WindowsInterfaceAllower) Close() error {
|
||||||
|
if !isWindowsFirewallReachable() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if isFirewallRuleActive(firewallRuleName) {
|
||||||
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||||
|
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||||
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||||
if action == addRule {
|
if action == addRule {
|
||||||
args = append(args, extraArgs...)
|
args = append(args, extraArgs...)
|
||||||
@@ -7,8 +7,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||||
@@ -60,7 +58,7 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||||
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
func (m *localIPManager) UpdateLocalIPs(iface Iface) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = fmt.Errorf("panic: %v", r)
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
|||||||
@@ -487,19 +487,13 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (m *Manager) AddDNATRule(firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
if m.nativeFirewall == nil {
|
return nil, errNotSupported
|
||||||
return nil, errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes outbound DNAT rule.
|
// DeleteDNATRule deletes outbound DNAT rule.
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteDNATRule(firewall.Rule) error {
|
||||||
if m.nativeFirewall == nil {
|
return errNotSupported
|
||||||
return errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addPortRedirection adds a port redirection rule.
|
// addPortRedirection adds a port redirection rule.
|
||||||
@@ -521,7 +515,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
@@ -567,20 +560,16 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT delegates to the native firewall if available.
|
// AddOutputDNAT is not supported by the userspace firewall: it backs kernel DNS
|
||||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
// redirection, but userspace DNS is served in-process on the gVisor netstack, so
|
||||||
if m.nativeFirewall == nil {
|
// this should never be called.
|
||||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
func (m *Manager) AddOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
|
||||||
}
|
return errNotSupported
|
||||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
// RemoveOutputDNAT is a no-op for the userspace firewall (see AddOutputDNAT).
|
||||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) RemoveOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
return nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||||
|
|||||||
@@ -64,9 +64,11 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -124,9 +126,11 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
|
|
||||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -196,9 +200,11 @@ func BenchmarkDNATScaling(b *testing.B) {
|
|||||||
|
|
||||||
for _, count := range mappingCounts {
|
for _, count := range mappingCounts {
|
||||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -308,9 +314,11 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
|||||||
|
|
||||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -481,9 +489,11 @@ func BenchmarkPortDNAT(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ import (
|
|||||||
|
|
||||||
// TestPortDNATBasic tests basic port DNAT functionality
|
// TestPortDNATBasic tests basic port DNAT functionality
|
||||||
func TestPortDNATBasic(t *testing.T) {
|
func TestPortDNATBasic(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -49,9 +51,11 @@ func TestPortDNATBasic(t *testing.T) {
|
|||||||
|
|
||||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||||
func TestPortDNATMultipleRules(t *testing.T) {
|
func TestPortDNATMultipleRules(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ import (
|
|||||||
|
|
||||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -104,9 +106,11 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|||||||
|
|
||||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||||
func TestDNATMappingManagement(t *testing.T) {
|
func TestDNATMappingManagement(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -152,9 +156,11 @@ func TestDNATMappingManagement(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInboundPortDNAT(t *testing.T) {
|
func TestInboundPortDNAT(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -202,9 +208,11 @@ func TestInboundPortDNAT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInboundPortDNATNegative(t *testing.T) {
|
func TestInboundPortDNATNegative(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(Config{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
IFace: &IFaceMock{
|
||||||
}, false, flowLogger, iface.DefaultMTU)
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
333
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
333
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
//go:build uspbench
|
||||||
|
|
||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
|
||||||
|
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
|
||||||
|
// rules, each with K source peers in its set.
|
||||||
|
//
|
||||||
|
// With the reverse-source index, miss cost is independent of M and
|
||||||
|
// hit cost grows only with the number of rules touching a single
|
||||||
|
// srcIP, not with total rule count.
|
||||||
|
func BenchmarkPeerACLMatch(b *testing.B) {
|
||||||
|
shapes := []struct{ M, K int }{
|
||||||
|
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
|
||||||
|
}
|
||||||
|
families := []struct {
|
||||||
|
name string
|
||||||
|
v6 bool
|
||||||
|
}{{"v4", false}, {"v6", true}}
|
||||||
|
|
||||||
|
for _, fam := range families {
|
||||||
|
for _, s := range shapes {
|
||||||
|
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
|
||||||
|
runPeerACLBench(b, s.M, s.K, true, fam.v6)
|
||||||
|
})
|
||||||
|
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
|
||||||
|
runPeerACLBench(b, s.M, s.K, false, fam.v6)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
|
||||||
|
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
|
||||||
|
|
||||||
|
// Miss packets are dropped, so they always traverse the full peer
|
||||||
|
// ACL matcher (every bucket) without short-circuiting and without
|
||||||
|
// touching conntrack. Disable conntrack for the miss case so it
|
||||||
|
// measures the matcher, not established-state lookups. The hit case
|
||||||
|
// keeps conntrack on: an accepted packet reaches trackInbound, which
|
||||||
|
// needs the trackers conntrack creates.
|
||||||
|
if !hit {
|
||||||
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
|
}
|
||||||
|
|
||||||
|
bits := 32
|
||||||
|
genPkt := generatePacket
|
||||||
|
addrs := uniqueAddrs
|
||||||
|
if v6 {
|
||||||
|
bits = 128
|
||||||
|
genPkt = generatePacket6
|
||||||
|
addrs = uniqueAddrs6
|
||||||
|
}
|
||||||
|
|
||||||
|
// dstIP must be a local IP so filterInbound takes the local-traffic
|
||||||
|
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
|
||||||
|
// address the manager doesn't own would be treated as routed and
|
||||||
|
// short-circuit before the peer matcher.
|
||||||
|
dstIP := addrs(1, 2)[0]
|
||||||
|
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
|
||||||
|
if v6 {
|
||||||
|
// The local-IP manager needs a valid v4 address too; expose the v6
|
||||||
|
// dst as the interface's IPv6 so IsLocalIP recognizes it.
|
||||||
|
mockAddr = wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
IPv6: dstIP,
|
||||||
|
IPv6Net: netip.PrefixFrom(dstIP, bits),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager, err := Create(Config{
|
||||||
|
IFace: &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address { return mockAddr },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
|
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
|
||||||
|
|
||||||
|
// Generate M policies × K source peers, all distinct.
|
||||||
|
all := addrs(m*k, 1)
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
sources := make([]netip.Prefix, k)
|
||||||
|
for j, a := range all[i*k : (i+1)*k] {
|
||||||
|
sources[j] = netip.PrefixFrom(a, bits)
|
||||||
|
}
|
||||||
|
_, err := manager.AddFilterRule(
|
||||||
|
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||||
|
fw.ActionAccept)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hit: cycle through real sources, picking the matching policy's port.
|
||||||
|
// Miss: a source from a disjoint range, port 80 (matches no policy).
|
||||||
|
var pktFn func(i int) []byte
|
||||||
|
if hit {
|
||||||
|
pktFn = func(i int) []byte {
|
||||||
|
policy := i % m
|
||||||
|
src := all[policy*k+(i%k)]
|
||||||
|
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
|
||||||
|
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
miss := addrs(4096, 99)
|
||||||
|
pktFn = func(i int) []byte {
|
||||||
|
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
|
||||||
|
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-build a pool to avoid allocations dominating the measurement.
|
||||||
|
pool := make([][]byte, 1024)
|
||||||
|
for i := range pool {
|
||||||
|
pool[i] = pktFn(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Confirm the matcher is actually exercised: a hit packet must be
|
||||||
|
// allowed and a miss packet dropped. Without this the benchmark
|
||||||
|
// could silently time the routed early-return instead.
|
||||||
|
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
|
||||||
|
"benchmark must reach the peer ACL matcher")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.filterInbound(pool[i%len(pool)], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
|
||||||
|
// the source-keyed index across realistic deployment shapes. Two
|
||||||
|
// dimensions matter: (M, K), the number of policies × peers-per-policy,
|
||||||
|
// and overlap, the fraction of peers shared between policies.
|
||||||
|
//
|
||||||
|
// The output uses ReportMetric("bytes/rule") so the cost can be
|
||||||
|
// compared across shapes directly. Total bytes = bytes/rule * M.
|
||||||
|
func BenchmarkPeerACLIndexMemory(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
M, K int
|
||||||
|
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
|
||||||
|
}{
|
||||||
|
{"M=10/K=100/disjoint", 10, 100, 0},
|
||||||
|
{"M=100/K=100/disjoint", 100, 100, 0},
|
||||||
|
{"M=100/K=1000/disjoint", 100, 1000, 0},
|
||||||
|
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
|
||||||
|
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
|
||||||
|
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
b.Run(c.name, func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mgr, err := Create(Config{
|
||||||
|
IFace: &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
},
|
||||||
|
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
|
||||||
|
|
||||||
|
runtime.GC()
|
||||||
|
var ms runtime.MemStats
|
||||||
|
runtime.ReadMemStats(&ms)
|
||||||
|
before := ms.HeapAlloc
|
||||||
|
|
||||||
|
// Drop the manager's external roots so we can isolate
|
||||||
|
// the index cost. We hold the manager itself live; the
|
||||||
|
// index is what we measure on the second pass.
|
||||||
|
mgr.incomingAcceptIndex.reset()
|
||||||
|
mgr.incomingDenyIndex.reset()
|
||||||
|
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
|
||||||
|
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
|
||||||
|
runtime.GC()
|
||||||
|
runtime.ReadMemStats(&ms)
|
||||||
|
after := ms.HeapAlloc
|
||||||
|
|
||||||
|
delta := int64(before) - int64(after)
|
||||||
|
if delta < 0 {
|
||||||
|
delta = 0
|
||||||
|
}
|
||||||
|
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
|
||||||
|
b.ReportMetric(float64(delta), "bytes/total")
|
||||||
|
|
||||||
|
require.NoError(b, mgr.Close(nil))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
|
||||||
|
b.Helper()
|
||||||
|
pool := uniqueAddrs(k+m*k, 1) // big enough universe
|
||||||
|
sharedLen := int(float64(k) * overlapFrac)
|
||||||
|
if sharedLen > k {
|
||||||
|
sharedLen = k
|
||||||
|
}
|
||||||
|
shared := pool[:sharedLen]
|
||||||
|
uniquePool := pool[sharedLen:]
|
||||||
|
|
||||||
|
for i := 0; i < m; i++ {
|
||||||
|
sources := make([]netip.Prefix, 0, k)
|
||||||
|
for _, a := range shared {
|
||||||
|
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||||
|
}
|
||||||
|
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
|
||||||
|
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
|
||||||
|
for _, a := range unique {
|
||||||
|
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||||
|
}
|
||||||
|
_, err := mgr.AddFilterRule(
|
||||||
|
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||||
|
fw.ActionAccept)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
|
||||||
|
// policy sources / dst; seed 99 puts misses in 10/8.
|
||||||
|
func uniqueAddrs(n int, seed int64) []netip.Addr {
|
||||||
|
out := make([]netip.Addr, 0, n)
|
||||||
|
seen := make(map[netip.Addr]struct{}, n)
|
||||||
|
r := rand.New(rand.NewSource(seed))
|
||||||
|
miss := seed == 99
|
||||||
|
for len(out) < n {
|
||||||
|
var b [4]byte
|
||||||
|
if miss {
|
||||||
|
b[0] = 10
|
||||||
|
b[1] = byte(r.Intn(256))
|
||||||
|
} else {
|
||||||
|
b[0] = 100
|
||||||
|
b[1] = byte(64 + r.Intn(63))
|
||||||
|
}
|
||||||
|
b[2] = byte(r.Intn(256))
|
||||||
|
b[3] = byte(1 + r.Intn(254))
|
||||||
|
a := netip.AddrFrom4(b)
|
||||||
|
if _, ok := seen[a]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[a] = struct{}{}
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
|
||||||
|
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
|
||||||
|
// disjoint from any source.
|
||||||
|
func uniqueAddrs6(n int, seed int64) []netip.Addr {
|
||||||
|
out := make([]netip.Addr, 0, n)
|
||||||
|
seen := make(map[netip.Addr]struct{}, n)
|
||||||
|
r := rand.New(rand.NewSource(seed))
|
||||||
|
miss := seed == 99
|
||||||
|
for len(out) < n {
|
||||||
|
var b [16]byte
|
||||||
|
if miss {
|
||||||
|
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
|
||||||
|
} else {
|
||||||
|
b[0] = 0xfd
|
||||||
|
}
|
||||||
|
for x := 8; x < 16; x++ {
|
||||||
|
b[x] = byte(r.Intn(256))
|
||||||
|
}
|
||||||
|
a := netip.AddrFrom16(b)
|
||||||
|
if _, ok := seen[a]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[a] = struct{}{}
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
|
||||||
|
// generatePacket for the v4 case.
|
||||||
|
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ipv6 := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: protocol,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transportLayer gopacket.SerializableLayer
|
||||||
|
switch protocol {
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
}
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
|
||||||
|
transportLayer = tcp
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
|
||||||
|
transportLayer = udp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
150
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
150
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
|
require.NoError(t, err, "create manager")
|
||||||
|
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
|
||||||
|
// the same peer rule twice does not create two backing rules. The acl
|
||||||
|
// manager keys its own cache, but the firewall backend must be
|
||||||
|
// idempotent on its own so a double-apply cannot leak rules, matching
|
||||||
|
// the route path and the kernel backends.
|
||||||
|
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
|
||||||
|
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err, "first add")
|
||||||
|
|
||||||
|
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err, "second add")
|
||||||
|
|
||||||
|
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
|
||||||
|
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
|
||||||
|
// backend's owner accounting for the same-owner case: a content key
|
||||||
|
// installed twice by the same owner registers one owner claim, so the
|
||||||
|
// first DeleteFilterRule removes the rule. Owner counting only kicks
|
||||||
|
// in for distinct management rule IDs (see the peer owner tests); the
|
||||||
|
// acl manager keys its tracking per (policy, content) and deletes once
|
||||||
|
// per key, so adds and deletes stay balanced.
|
||||||
|
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
|
||||||
|
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err, "first add")
|
||||||
|
|
||||||
|
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err, "second add")
|
||||||
|
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
|
||||||
|
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
|
||||||
|
|
||||||
|
require.NoError(t, m.DeleteFilterRule(first), "delete once")
|
||||||
|
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
|
||||||
|
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
|
||||||
|
// content hash, not a random UUID: identical inputs produce the same id
|
||||||
|
// across independent managers. A random id breaks caller-side dedup.
|
||||||
|
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
|
||||||
|
ip := net.ParseIP("10.0.0.5")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
|
||||||
|
m1 := newTestManager(t)
|
||||||
|
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m2 := newTestManager(t)
|
||||||
|
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
|
||||||
|
// differing only by port are kept separate.
|
||||||
|
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
|
||||||
|
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
|
||||||
|
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
|
||||||
|
// matching on source port and one matching on destination port for the
|
||||||
|
// same selector do not collide: the port lands in a different slot, so
|
||||||
|
// the content key must differ.
|
||||||
|
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
|
||||||
|
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
|
||||||
|
// list is rejected rather than treated as "match any". "Match any" must
|
||||||
|
// be an explicit /0, so a zeroed list can never silently widen a rule to
|
||||||
|
// every source.
|
||||||
|
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
|
||||||
|
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
|
||||||
|
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
|
||||||
|
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
|
||||||
|
}
|
||||||
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newV6TestManager(t *testing.T, localV6 string) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
|
IPv6: netip.MustParseAddr(localV6),
|
||||||
|
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
||||||
|
require.NoError(t, err, "create manager")
|
||||||
|
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
|
||||||
|
t.Helper()
|
||||||
|
ip6 := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 64,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
SrcIP: net.ParseIP(src),
|
||||||
|
DstIP: net.ParseIP(dst),
|
||||||
|
}
|
||||||
|
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
|
||||||
|
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
|
||||||
|
// rules: a matching v6 source is accepted, a non-matching one is
|
||||||
|
// denied by the default. This is the end-to-end proof that the index
|
||||||
|
// is not v4-only.
|
||||||
|
func TestPeerACL_IPv6HostRule(t *testing.T) {
|
||||||
|
m := newV6TestManager(t, "fd00::100")
|
||||||
|
|
||||||
|
src := net.ParseIP("fd00::1")
|
||||||
|
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "add v6 accept rule")
|
||||||
|
|
||||||
|
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
|
||||||
|
"v6 packet from the allowed /128 source must be accepted")
|
||||||
|
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
|
||||||
|
"v6 packet from an unlisted source must be denied by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
|
||||||
|
// right index bucket: a /128 in bySource keyed by its address, and
|
||||||
|
// coarser prefixes (including ::/0) in the nonHost slice.
|
||||||
|
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
|
||||||
|
m := newV6TestManager(t, "fd00::100")
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
|
||||||
|
host := netip.MustParseAddr("fd00::1")
|
||||||
|
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
|
||||||
|
|
||||||
|
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
|
||||||
|
|
||||||
|
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
|
||||||
|
// source prefix is normalized to v4 so a plain v4 packet matches it.
|
||||||
|
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
|
||||||
|
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
v4 := netip.MustParseAddr("192.168.1.1")
|
||||||
|
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
|
||||||
|
}
|
||||||
106
client/firewall/uspfilter/peer_family_scope_test.go
Normal file
106
client/firewall/uspfilter/peer_family_scope_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// peerACLCheck decodes the packet and runs it through the peer ACLs,
|
||||||
|
// returning the attributed management rule id and the drop verdict.
|
||||||
|
func peerACLCheck(t *testing.T, m *Manager, packet []byte) ([]byte, bool) {
|
||||||
|
t.Helper()
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
require.NoError(t, d.decodePacket(packet))
|
||||||
|
src, _ := m.extractIPs(d)
|
||||||
|
return m.peerACLsBlock(src, d, packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerACL_MultiValuePortMatchesEachListedPort guards the multi-value
|
||||||
|
// port path: a rule listing several discrete destination ports must
|
||||||
|
// match a packet to each listed port and drop one that is not listed.
|
||||||
|
// Management currently splits a multi-port policy into one rule per port
|
||||||
|
// (and the wire format carries a single port), so this list shape is not
|
||||||
|
// emitted today; the test locks correct matching in case that changes.
|
||||||
|
func TestPeerACL_MultiValuePortMatchesEachListedPort(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
src := net.ParseIP("192.168.1.1")
|
||||||
|
ports := &fw.Port{Values: []uint16{80, 443}}
|
||||||
|
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolTCP, nil, ports, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "add multi-value port rule")
|
||||||
|
|
||||||
|
for _, p := range []uint16{80, 443} {
|
||||||
|
_, blocked := peerACLCheck(t, m, createTestPacket(t, "192.168.1.1", "10.0.0.2", fw.ProtocolTCP, 12345, p))
|
||||||
|
assert.False(t, blocked, "packet to listed port %d must match the rule", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, blocked := peerACLCheck(t, m, createTestPacket(t, "192.168.1.1", "10.0.0.2", fw.ProtocolTCP, 12345, 8080))
|
||||||
|
assert.True(t, blocked, "packet to a port not in the list must not match the rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerACL_MatchAnyIsFamilyScoped verifies that a /0 source matches
|
||||||
|
// only packets of its own family: 0.0.0.0/0 must not match IPv6 packets
|
||||||
|
// and ::/0 must not match IPv4 packets, matching kernel backend
|
||||||
|
// semantics.
|
||||||
|
func TestPeerACL_MatchAnyIsFamilyScoped(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
v4Packet := createTestPacket(t, "10.0.0.1", "10.0.0.2", fw.ProtocolUDP, 12345, 53)
|
||||||
|
v6Packet := v6UDPPacket(t, "fd00::1", "fd00::100", 53)
|
||||||
|
|
||||||
|
v4Any := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
rule, err := m.AddFilterRule(nil, v4Any, fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "add v4 /0 rule")
|
||||||
|
|
||||||
|
_, blocked := peerACLCheck(t, m, v4Packet)
|
||||||
|
assert.False(t, blocked, "0.0.0.0/0 must match IPv4 packets")
|
||||||
|
_, blocked = peerACLCheck(t, m, v6Packet)
|
||||||
|
assert.True(t, blocked, "0.0.0.0/0 must not match IPv6 packets")
|
||||||
|
|
||||||
|
require.NoError(t, m.DeleteFilterRule(rule))
|
||||||
|
|
||||||
|
v6Any := []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
_, err = m.AddFilterRule(nil, v6Any, fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
|
require.NoError(t, err, "add v6 /0 rule")
|
||||||
|
|
||||||
|
_, blocked = peerACLCheck(t, m, v6Packet)
|
||||||
|
assert.False(t, blocked, "::/0 must match IPv6 packets")
|
||||||
|
_, blocked = peerACLCheck(t, m, v4Packet)
|
||||||
|
assert.True(t, blocked, "::/0 must not match IPv4 packets")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteACL_MixedFamilyZeroSourcesStayFamilySafe verifies the route
|
||||||
|
// path keeps per-prefix family matching when a single rule carries both
|
||||||
|
// 0.0.0.0/0 and ::/0 sources, as blockInvalidRouted does.
|
||||||
|
func TestRouteACL_MixedFamilyZeroSourcesStayFamilySafe(t *testing.T) {
|
||||||
|
m := newTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := m.AddFilterRule(nil, sources, fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = m.AddFilterRule(nil, sources, fw.Network{Prefix: netip.MustParsePrefix("fd00:1::/64")},
|
||||||
|
fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
v4Src := netip.MustParseAddr("192.168.1.1")
|
||||||
|
v6Src := netip.MustParseAddr("fd00::1")
|
||||||
|
|
||||||
|
_, pass := m.routeACLsPass(v4Src, netip.MustParseAddr("10.0.0.5"), 255, 0, 0)
|
||||||
|
assert.True(t, pass, "v4 source must match the v4 destination rule via 0.0.0.0/0")
|
||||||
|
_, pass = m.routeACLsPass(v6Src, netip.MustParseAddr("fd00:1::5"), 255, 0, 0)
|
||||||
|
assert.True(t, pass, "v6 source must match the v6 destination rule via ::/0")
|
||||||
|
_, pass = m.routeACLsPass(v6Src, netip.MustParseAddr("10.0.0.5"), 255, 0, 0)
|
||||||
|
assert.True(t, pass, "v6 source still passes the v4 destination rule via ::/0 in the same source list")
|
||||||
|
}
|
||||||
140
client/firewall/uspfilter/peer_index.go
Normal file
140
client/firewall/uspfilter/peer_index.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// peerRuleIndex is the source-side dispatcher consulted on the packet
|
||||||
|
// hot path. It splits rules into two buckets by the shape of their
|
||||||
|
// source list:
|
||||||
|
//
|
||||||
|
// - bySource: every source is a host prefix (/32 for v4, /128 for
|
||||||
|
// v6). Keyed by the concrete source address, so a hit guarantees
|
||||||
|
// the source filter passes and the matcher goes straight to
|
||||||
|
// proto/port checks. This is the common case for peer ACLs.
|
||||||
|
// - nonHost: any source list with a prefix coarser than a host,
|
||||||
|
// including a /0 "match any". Walked linearly with a per-rule
|
||||||
|
// Contains() check. Expected small or empty for typical peer ACLs.
|
||||||
|
//
|
||||||
|
// Maintained incrementally by add/remove, never rebuilt.
|
||||||
|
type peerRuleIndex struct {
|
||||||
|
bySource map[netip.Addr][]*PeerRule
|
||||||
|
nonHost []*PeerRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *peerRuleIndex) add(r *PeerRule) {
|
||||||
|
if hasNonHostSource(r) {
|
||||||
|
i.nonHost = append(i.nonHost, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if i.bySource == nil {
|
||||||
|
i.bySource = make(map[netip.Addr][]*PeerRule)
|
||||||
|
}
|
||||||
|
for a := range r.sourceAddrs {
|
||||||
|
i.bySource[a] = append(i.bySource[a], r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *peerRuleIndex) remove(r *PeerRule) {
|
||||||
|
if hasNonHostSource(r) {
|
||||||
|
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if i.bySource == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for a := range r.sourceAddrs {
|
||||||
|
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
|
||||||
|
if len(entries) == 0 {
|
||||||
|
delete(i.bySource, a)
|
||||||
|
} else {
|
||||||
|
i.bySource[a] = entries
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *peerRuleIndex) reset() {
|
||||||
|
i.bySource = nil
|
||||||
|
i.nonHost = i.nonHost[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// match returns the first rule matching src and the decoded packet.
|
||||||
|
// Host rules are found by direct map lookup; nonHost rules run a
|
||||||
|
// per-rule source Contains() check. Containment is family-scoped, so
|
||||||
|
// a /0 source matches every address of its own family only (0.0.0.0/0
|
||||||
|
// never matches v6 sources and ::/0 never matches v4). Within either
|
||||||
|
// bucket the matcher runs the proto/port filter.
|
||||||
|
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
|
|
||||||
|
for _, rule := range i.bySource[src] {
|
||||||
|
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||||
|
return id, drop, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, rule := range i.nonHost {
|
||||||
|
if !prefixesContain(rule.sources, src) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||||
|
return id, drop, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func eqRule(target *PeerRule) func(*PeerRule) bool {
|
||||||
|
return func(p *PeerRule) bool { return p == target }
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasNonHostSource reports whether the rule has any source prefix
|
||||||
|
// that is not a single host address. Called only at add/remove time,
|
||||||
|
// not on the packet path.
|
||||||
|
func hasNonHostSource(r *PeerRule) bool {
|
||||||
|
for _, p := range r.sources {
|
||||||
|
if p.Bits() != p.Addr().BitLen() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchProto applies the proto/port half of a rule against the
|
||||||
|
// decoded packet. Source matching is the caller's responsibility.
|
||||||
|
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
|
||||||
|
drop := rule.action == firewall.ActionDrop
|
||||||
|
if rule.protoLayer == layerTypeAll {
|
||||||
|
return rule.mgmtId, drop, true
|
||||||
|
}
|
||||||
|
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
switch payloadLayer {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
|
||||||
|
return rule.mgmtId, drop, true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
|
||||||
|
return rule.mgmtId, drop, true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return rule.mgmtId, drop, true
|
||||||
|
}
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
|
||||||
|
for _, p := range sources {
|
||||||
|
if p.Contains(src) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -10,24 +10,43 @@ import (
|
|||||||
|
|
||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id firewall.RuleID
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
ip netip.Addr
|
// sources is the canonical list of source prefixes this rule
|
||||||
ipLayer gopacket.LayerType
|
// matches against.
|
||||||
matchByIP bool
|
sources []netip.Prefix
|
||||||
protoLayer gopacket.LayerType
|
// sourceAddrs is a fast-path membership set for host-prefix
|
||||||
sPort *firewall.Port
|
// sources (/32 v4, /128 v6). Populated alongside sources;
|
||||||
dPort *firewall.Port
|
// consulted before falling back to prefix scan.
|
||||||
drop bool
|
sourceAddrs map[netip.Addr]struct{}
|
||||||
|
protoLayer gopacket.LayerType
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesSource reports whether the given source address is covered
|
||||||
|
// by this rule's source list. Prefix containment is family-scoped, so
|
||||||
|
// a /0 source matches every address of its own family only.
|
||||||
|
func (r *PeerRule) matchesSource(src netip.Addr) bool {
|
||||||
|
if _, ok := r.sourceAddrs[src]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, p := range r.sources {
|
||||||
|
if p.Contains(src) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *PeerRule) ID() string {
|
func (r *PeerRule) ID() firewall.RuleID {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id firewall.RuleID
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dstSet firewall.Set
|
dstSet firewall.Set
|
||||||
@@ -39,6 +58,6 @@ type RouteRule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *RouteRule) ID() string {
|
func (r *RouteRule) ID() firewall.RuleID {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// countRulesForAddr reports how many rules in the given slice match
|
||||||
|
// the supplied source address.
|
||||||
|
func countRulesForAddr(rules peerRules, src netip.Addr) int {
|
||||||
|
n := 0
|
||||||
|
for _, r := range rules {
|
||||||
|
if r.matchesSource(src) {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRuleByID returns true if the rules slice contains a rule with
|
||||||
|
// the given id whose source set covers src.
|
||||||
|
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
|
||||||
|
for _, r := range rules {
|
||||||
|
if r.id == id && r.matchesSource(src) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pfx converts a single net.IP into the []netip.Prefix form
|
||||||
|
// AddFilterRule expects. A nil or unspecified address becomes a /0
|
||||||
|
// ("match any") prefix in the matching family; any other address
|
||||||
|
// becomes its /32 (or /128) host prefix.
|
||||||
|
func pfx(ip net.IP) []netip.Prefix {
|
||||||
|
if ip == nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
if ip.IsUnspecified() {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||||
|
}
|
||||||
|
a, _ := netip.AddrFromSlice(ip)
|
||||||
|
a = a.Unmap()
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||||
|
}
|
||||||
@@ -285,6 +285,14 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace.SourceIP = srcIP
|
trace.SourceIP = srcIP
|
||||||
trace.DestinationIP = dstIP
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
|
// A fragment or otherwise truncated packet has no transport layer.
|
||||||
|
// The inbound datapath drops these via isValidPacket; the tracer must
|
||||||
|
// guard explicitly since every downstream stage reads d.decoded[1].
|
||||||
|
if len(d.decoded) < 2 {
|
||||||
|
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
// Determine protocol and ports
|
// Determine protocol and ports
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if !statefulMode {
|
if !statefulMode {
|
||||||
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
ip := net.ParseIP("1.1.1.1")
|
ip := net.ParseIP("1.1.1.1")
|
||||||
proto := fw.ProtocolICMP
|
proto := fw.ProtocolICMP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
ip := net.ParseIP("1.1.1.1")
|
ip := net.ParseIP("1.1.1.1")
|
||||||
proto := fw.ProtocolICMP
|
proto := fw.ProtocolICMP
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
port := &fw.Port{Values: []uint16{53}}
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux
|
//go:build !linux || !privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
|||||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
|
||||||
wgPort := 51850
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("192.168.0.56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
|
||||||
wgPort := 51851
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("fe80::56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||||
wgPort := 51852
|
wgPort := 51852
|
||||||
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51850
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51851
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||||
wgPort := 51856
|
wgPort := 51856
|
||||||
|
|||||||
190
client/internal/acl/dispatch_test.go
Normal file
190
client/internal/acl/dispatch_test.go
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
|
||||||
|
// invariant: the backends classify a rule as a peer rule purely by the
|
||||||
|
// absence of a destination (neither prefix nor set). A default route
|
||||||
|
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
|
||||||
|
// a route, not collapse into the peer path.
|
||||||
|
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
|
||||||
|
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
|
||||||
|
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
|
||||||
|
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
|
||||||
|
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A zero-value Network is the only peer-rule shape.
|
||||||
|
var empty fwmgr.Network
|
||||||
|
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
|
||||||
|
assert.False(t, empty.IsSet(), "zero Network must not be a set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDetermineDestinationAlwaysRoute verifies determineDestination
|
||||||
|
// never yields an empty Network for a valid route rule: every branch
|
||||||
|
// (static prefix, default route, dynamic with/without domains, with and
|
||||||
|
// without a local resolver) produces a destination that classifies as a
|
||||||
|
// route. If this regresses, a route rule would be dispatched down the
|
||||||
|
// peer path, which matches on source only.
|
||||||
|
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
|
||||||
|
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||||
|
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
rule *mgmProto.RouteFirewallRule
|
||||||
|
resolver bool
|
||||||
|
sources []netip.Prefix
|
||||||
|
}{
|
||||||
|
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
|
||||||
|
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
|
||||||
|
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
|
||||||
|
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
|
||||||
|
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
|
||||||
|
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
|
||||||
|
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, dest.IsPrefix() || dest.IsSet(),
|
||||||
|
"destination must classify as a route, got empty Network")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// countingFirewall wraps a real firewall.Manager and counts filter-rule
|
||||||
|
// add/delete calls so a test can assert how many backing rules the acl
|
||||||
|
// manager actually creates and tears down.
|
||||||
|
type countingFirewall struct {
|
||||||
|
fwmgr.Manager
|
||||||
|
mu sync.Mutex
|
||||||
|
addCalls int
|
||||||
|
dels int
|
||||||
|
ruleIDs map[fwmgr.RuleID]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// distinctRules returns the number of distinct backing rules the
|
||||||
|
// backend produced. Because the backend dedups identical content,
|
||||||
|
// repeated AddFilterRule calls for the same rule resolve to one id.
|
||||||
|
func (f *countingFirewall) distinctRules() int {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return len(f.ruleIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
||||||
|
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
if err == nil {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.addCalls++
|
||||||
|
if f.ruleIDs == nil {
|
||||||
|
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
|
||||||
|
}
|
||||||
|
if rule != nil {
|
||||||
|
f.ruleIDs[rule.ID()] = struct{}{}
|
||||||
|
}
|
||||||
|
f.mu.Unlock()
|
||||||
|
}
|
||||||
|
return rule, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||||
|
err := f.Manager.DeleteFilterRule(r)
|
||||||
|
if err == nil {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.dels++
|
||||||
|
delete(f.ruleIDs, r.ID())
|
||||||
|
f.mu.Unlock()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
fw := &countingFirewall{Manager: realFW}
|
||||||
|
cleanup := func() {
|
||||||
|
require.NoError(t, realFW.Close(nil))
|
||||||
|
ctrl.Finish()
|
||||||
|
}
|
||||||
|
return NewDefaultManager(fw), fw, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
|
||||||
|
// the backends rely on: two policies that authorize an identical flow
|
||||||
|
// (same selector and sources) collapse to a single backing firewall
|
||||||
|
// rule, and that rule survives until BOTH policies are gone. This is
|
||||||
|
// why the backend can dedup on add without refcounting on delete: the
|
||||||
|
// acl manager's pair key matches the backend's content key, so add and
|
||||||
|
// delete stay balanced per content key across full-state reapplies.
|
||||||
|
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
|
||||||
|
acl, fw, cleanup := newCountingACL(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ruleA := &mgmProto.FirewallRule{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
}
|
||||||
|
ruleB := &mgmProto.FirewallRule{
|
||||||
|
PolicyID: []byte("policy-B"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both policies present: identical content collapses to one rule.
|
||||||
|
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||||
|
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
|
||||||
|
|
||||||
|
// Drop policy A only: the shared rule is still authorized by B, so
|
||||||
|
// nothing is deleted.
|
||||||
|
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||||
|
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
|
||||||
|
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Drop policy B too: now the content key has no authorizer and the
|
||||||
|
// single backing rule is removed exactly once.
|
||||||
|
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
|
||||||
|
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
|
}
|
||||||
318
client/internal/acl/grouping_test.go
Normal file
318
client/internal/acl/grouping_test.go
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
|
||||||
|
// with identical selectors but different PolicyIDs do NOT get merged
|
||||||
|
// into one group, so each policy's sources merge under its own
|
||||||
|
// attribution id. (Identical-content groups may still dedup to one
|
||||||
|
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
|
||||||
|
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
|
||||||
|
rules := []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-B"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
|
||||||
|
// belonging to the same policy don't merge.
|
||||||
|
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
|
||||||
|
rules := []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "2001:db8::1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2, "rules of different families must produce separate groups")
|
||||||
|
|
||||||
|
var sawV4, sawV6 bool
|
||||||
|
for _, g := range groups {
|
||||||
|
require.Len(t, g.sources, 1)
|
||||||
|
if g.sources[0].Addr().Is4() {
|
||||||
|
sawV4 = true
|
||||||
|
}
|
||||||
|
if g.sources[0].Addr().Is6() {
|
||||||
|
sawV6 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, sawV4 && sawV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
|
||||||
|
// FirewallRule carrying both v4 and v6 source prefixes is split into one
|
||||||
|
// group per family. Each backend keys a rule to a single family, so a
|
||||||
|
// group whose sources span families would mismatch the other family's
|
||||||
|
// sources. mgmt normally emits one rule per family; this guards against
|
||||||
|
// a mixed-family rule slipping through.
|
||||||
|
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
|
||||||
|
srcs := [][]byte{
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
|
||||||
|
}
|
||||||
|
rules := []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
SourcePrefixes: srcs,
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
|
||||||
|
|
||||||
|
for _, g := range groups {
|
||||||
|
require.Len(t, g.sources, 2)
|
||||||
|
v6 := prefixIsV6(g.sources[0])
|
||||||
|
for _, s := range g.sources {
|
||||||
|
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
|
||||||
|
// every distinguishing field (policy, family, direction, action,
|
||||||
|
// proto, port) collapse into a single multi-source group.
|
||||||
|
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
|
||||||
|
mk := func(peerIP string) *mgmProto.FirewallRule {
|
||||||
|
return &mgmProto.FirewallRule{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: peerIP,
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 1)
|
||||||
|
require.Len(t, groups[0].sources, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
|
||||||
|
// selector key: rules differing only in port must not merge, and a
|
||||||
|
// single port must not merge with a range. A regression dropping the
|
||||||
|
// port from the key would collapse rules for different ports into one.
|
||||||
|
func TestGroupPeerRulesPortSeparates(t *testing.T) {
|
||||||
|
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
|
||||||
|
return &mgmProto.FirewallRule{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: peerIP,
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
|
||||||
|
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
|
||||||
|
})
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2, "rules on different ports must not merge")
|
||||||
|
|
||||||
|
rangeRule := &mgmProto.FirewallRule{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
|
||||||
|
}
|
||||||
|
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2, "a single port and a range must not merge")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
|
||||||
|
// new sourcePrefixes wire field is consumed and produces a
|
||||||
|
// multi-source group in one shot (no client-side merging needed).
|
||||||
|
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
|
||||||
|
srcs := [][]byte{
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||||
|
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
|
||||||
|
}
|
||||||
|
rules := []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
SourcePrefixes: srcs,
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 1)
|
||||||
|
require.Len(t, groups[0].sources, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
|
||||||
|
// and drop rules with the same selector don't merge.
|
||||||
|
func TestGroupPeerRulesActionSeparates(t *testing.T) {
|
||||||
|
rules := []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
require.NoError(t, denyErr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// failingDeleteFirewall wraps a real firewall.Manager and forces the
|
||||||
|
// next N DeleteFilterRule calls to fail. Used to verify that the acl
|
||||||
|
// manager retains rules whose deletion was rejected by the backend,
|
||||||
|
// so they get retried on the next ApplyFiltering pass instead of
|
||||||
|
// becoming orphans.
|
||||||
|
type failingDeleteFirewall struct {
|
||||||
|
fwmgr.Manager
|
||||||
|
failCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||||
|
if f.failCount > 0 {
|
||||||
|
f.failCount--
|
||||||
|
return errors.New("simulated delete failure")
|
||||||
|
}
|
||||||
|
return f.Manager.DeleteFilterRule(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
|
||||||
|
// transient DeleteFilterRule error doesn't make the acl manager
|
||||||
|
// forget about a rule. The rule must remain in peerRulesPairs so the
|
||||||
|
// next ApplyFiltering pass attempts the delete again.
|
||||||
|
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
||||||
|
|
||||||
|
fw := &failingDeleteFirewall{Manager: realFW}
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First pass: install a rule.
|
||||||
|
netmap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PolicyID: []byte("policy-A"),
|
||||||
|
PeerIP: "10.0.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(netmap1, false)
|
||||||
|
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
|
||||||
|
|
||||||
|
// Second pass: remove the rule from the map. The backend will
|
||||||
|
// fail the delete; the acl manager must retain the rule.
|
||||||
|
fw.failCount = 1
|
||||||
|
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
|
||||||
|
acl.ApplyFiltering(netmap2, false)
|
||||||
|
require.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"rule must be retained when DeleteFilterRule fails so it gets retried")
|
||||||
|
|
||||||
|
// Third pass: same map, backend no longer fails. The rule
|
||||||
|
// should now succeed in being removed.
|
||||||
|
acl.ApplyFiltering(netmap2, false)
|
||||||
|
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
|
||||||
|
}
|
||||||
@@ -5,18 +5,18 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RuleID string
|
// RuleID aliases manager.RuleID so existing nbid.RuleID references
|
||||||
|
// keep working while the canonical type lives in the firewall package.
|
||||||
|
type RuleID = manager.RuleID
|
||||||
|
|
||||||
func (r RuleID) ID() string {
|
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
|
||||||
return string(r)
|
func GenerateRuleID(
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateRouteRuleKey(
|
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination manager.Network,
|
destination manager.Network,
|
||||||
proto manager.Protocol,
|
proto manager.Protocol,
|
||||||
@@ -24,6 +24,7 @@ func GenerateRouteRuleKey(
|
|||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
action manager.Action,
|
action manager.Action,
|
||||||
) RuleID {
|
) RuleID {
|
||||||
|
sources = slices.Clone(sources)
|
||||||
manager.SortPrefixes(sources)
|
manager.SortPrefixes(sources)
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
|
|||||||
75
client/internal/acl/legacy_fallback_test.go
Normal file
75
client/internal/acl/legacy_fallback_test.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sourcesRecordingFirewall wraps a real firewall.Manager and records
|
||||||
|
// the source prefixes of every AddFilterRule call.
|
||||||
|
type sourcesRecordingFirewall struct {
|
||||||
|
fwmgr.Manager
|
||||||
|
mu sync.Mutex
|
||||||
|
sources [][]netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *sourcesRecordingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.sources = append(f.sources, sources)
|
||||||
|
f.mu.Unlock()
|
||||||
|
return f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLegacyManagementFallbackUsesMatchAnySources verifies the
|
||||||
|
// allow-all fallback for old management servers (empty FirewallRules
|
||||||
|
// without the FirewallRulesIsEmpty flag) reaches the firewall as /0
|
||||||
|
// match-any sources. The fallback rule carries PeerIP 0.0.0.0; if that
|
||||||
|
// were converted to a host prefix (0.0.0.0/32) it would match nothing
|
||||||
|
// and all peer traffic would be dropped.
|
||||||
|
func TestLegacyManagementFallbackUsesMatchAnySources(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
||||||
|
|
||||||
|
fw := &sourcesRecordingFirewall{Manager: realFW}
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Old management: no rules and no FirewallRulesIsEmpty flag.
|
||||||
|
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: false}, false)
|
||||||
|
|
||||||
|
fw.mu.Lock()
|
||||||
|
defer fw.mu.Unlock()
|
||||||
|
require.NotEmpty(t, fw.sources, "legacy fallback must install at least one allow-all rule")
|
||||||
|
for _, sources := range fw.sources {
|
||||||
|
require.NotEmpty(t, sources)
|
||||||
|
for _, p := range sources {
|
||||||
|
assert.Equal(t, 0, p.Bits(), "legacy fallback source %s must be a /0 match-any prefix", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -23,6 +21,10 @@ import (
|
|||||||
|
|
||||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
|
|
||||||
|
// ErrNoRuleReturned is returned when the firewall backend reports success
|
||||||
|
// from AddFilterRule but yields no rule to track.
|
||||||
|
var ErrNoRuleReturned = errors.New("backend returned no rule")
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
@@ -31,17 +33,46 @@ type Manager interface {
|
|||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
ipsetCounter int
|
|
||||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||||
routeRules map[id.RuleID]struct{}
|
routeRules map[id.RuleID]firewall.Rule
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerRuleGroup collapses a set of single-source FirewallRules sharing
|
||||||
|
// the same selector into one multi-source rule to push to the backend.
|
||||||
|
type peerRuleGroup struct {
|
||||||
|
direction mgmProto.RuleDirection
|
||||||
|
action mgmProto.RuleAction
|
||||||
|
protocol mgmProto.RuleProtocol
|
||||||
|
port *mgmProto.PortInfo
|
||||||
|
// legacyPort is used only when PortInfo is empty (old management).
|
||||||
|
legacyPort string
|
||||||
|
policyID []byte
|
||||||
|
sources []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// peerRuleKey is the comparable selector that decides which single-source
|
||||||
|
// rules merge into one group. Rules with an equal key collapse into one
|
||||||
|
// multi-source backend rule. PortInfo is flattened into its scalar fields
|
||||||
|
// so the key compares by value; policyID keeps policies separate so two
|
||||||
|
// policies authorizing different peers don't merge under one attribution.
|
||||||
|
type peerRuleKey struct {
|
||||||
|
v6 bool
|
||||||
|
policyID string
|
||||||
|
direction mgmProto.RuleDirection
|
||||||
|
action mgmProto.RuleAction
|
||||||
|
protocol mgmProto.RuleProtocol
|
||||||
|
legacyPort string
|
||||||
|
port uint16
|
||||||
|
rangeStart uint16
|
||||||
|
rangeEnd uint16
|
||||||
|
}
|
||||||
|
|
||||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
return &DefaultManager{
|
return &DefaultManager{
|
||||||
firewall: fm,
|
firewall: fm,
|
||||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||||
routeRules: make(map[id.RuleID]struct{}),
|
routeRules: make(map[id.RuleID]firewall.Rule),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,10 +99,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
if err := d.applyPeerACLs(networkMap); err != nil {
|
||||||
|
log.Errorf("apply peer ACLs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.firewall.Flush(); err != nil {
|
if err := d.firewall.Flush(); err != nil {
|
||||||
@@ -79,7 +112,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
|
||||||
rules := networkMap.FirewallRules
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||||
@@ -102,59 +135,167 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
// Group incoming single-source rules from management by their
|
||||||
ipsetByRuleSelectors := make(map[string]string)
|
// (direction, action, proto, port) selector and merge sources.
|
||||||
|
// One call to the firewall backend per merged rule.
|
||||||
|
// A deny we cannot decode would leave its traffic unblocked, so skip
|
||||||
|
// the whole pass and keep existing rules until the next sync.
|
||||||
|
groups, denyErr, err := groupPeerRules(rules)
|
||||||
|
if denyErr != nil {
|
||||||
|
return fmt.Errorf("decode deny rule sources: %w", denyErr)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
|
||||||
// the missing deny. Currently we accumulate errors and continue.
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, r := range rules {
|
if err != nil {
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
merr = multierror.Append(merr, err)
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
}
|
||||||
selector := d.getRuleGroupingSelector(r)
|
|
||||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
// Apply denies first. A deny that fails to install is a security
|
||||||
if !ok {
|
// failure (fail-open), so if any deny errors we roll back the
|
||||||
d.ipsetCounter++
|
// denies we already installed in this pass and bail out without
|
||||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
// installing any accept. Pre-existing rules stay untouched until
|
||||||
ipsetByRuleSelectors[selector] = ipsetName
|
// the next successful pass clears them.
|
||||||
}
|
denies, accepts := splitDenyAccept(groups)
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("install deny rules: %w", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
}
|
||||||
|
|
||||||
|
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tear down rules that disappeared from the networkmap. Any rule
|
||||||
|
// the backend refuses to delete stays in our tracking so it gets
|
||||||
|
// retried on the next ApplyFiltering. Otherwise a transient
|
||||||
|
// delete failure would leak the rule in the firewall until the
|
||||||
|
// process exits.
|
||||||
|
for pairID, rules := range d.peerRulesPairs {
|
||||||
|
if _, ok := newRulePairs[pairID]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(rulePair) > 0 {
|
var remaining []firewall.Rule
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
for _, rule := range rules {
|
||||||
newRulePairs[pairID] = rulePair
|
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||||
}
|
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
|
||||||
}
|
remaining = append(remaining, rule)
|
||||||
|
|
||||||
if merr != nil {
|
|
||||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
|
||||||
}
|
|
||||||
|
|
||||||
for pairID, rules := range d.peerRulesPairs {
|
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete(d.peerRulesPairs, pairID)
|
}
|
||||||
|
if len(remaining) > 0 {
|
||||||
|
newRulePairs[pairID] = remaining
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
d.peerRulesPairs = newRulePairs
|
d.peerRulesPairs = newRulePairs
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// installPeerGroups applies each group and records the resulting rule
|
||||||
|
// pairs in newRulePairs. With atomic set (deny rules), a single failure
|
||||||
|
// rolls back every rule installed in this call and returns, leaving the
|
||||||
|
// firewall exactly as before: denies are fail-closed and must be applied
|
||||||
|
// all-or-nothing. With atomic unset (accept rules), failures are
|
||||||
|
// accumulated and the remaining groups still install, so one malformed
|
||||||
|
// allow cannot drop every other legitimate allow in the pass.
|
||||||
|
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
|
||||||
|
var freshlyInstalled []id.RuleID
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, g := range groups {
|
||||||
|
pairID, rulePair, err := d.applyPeerGroup(g)
|
||||||
|
if err != nil {
|
||||||
|
if atomic {
|
||||||
|
d.rollbackInstalled(freshlyInstalled)
|
||||||
|
return fmt.Errorf("apply firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(rulePair) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, existed := d.peerRulesPairs[pairID]; !existed {
|
||||||
|
freshlyInstalled = append(freshlyInstalled, pairID)
|
||||||
|
}
|
||||||
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
|
newRulePairs[pairID] = rulePair
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, pairID := range pairIDs {
|
||||||
|
// Keep any rule the backend refuses to delete tracked so it is
|
||||||
|
// retried on the next ApplyFiltering instead of leaking in the
|
||||||
|
// firewall with no tracking left to remove it.
|
||||||
|
var remaining []firewall.Rule
|
||||||
|
for _, rule := range d.peerRulesPairs[pairID] {
|
||||||
|
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
|
||||||
|
remaining = append(remaining, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(remaining) > 0 {
|
||||||
|
d.peerRulesPairs[pairID] = remaining
|
||||||
|
} else {
|
||||||
|
delete(d.peerRulesPairs, pairID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := nberrors.FormatErrorOrNil(merr); err != nil {
|
||||||
|
log.Errorf("rollback peer rules: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
|
||||||
|
protocol, err := ConvertToFirewallProtocol(g.protocol)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
action, err := convertFirewallAction(g.action)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
port, err := resolveGroupPort(g)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var fwRule firewall.Rule
|
||||||
|
switch g.direction {
|
||||||
|
case mgmProto.RuleDirection_IN:
|
||||||
|
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
|
||||||
|
case mgmProto.RuleDirection_OUT:
|
||||||
|
if d.firewall.IsStateful() {
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
|
||||||
|
default:
|
||||||
|
return "", nil, errors.New("invalid direction")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
if fwRule == nil {
|
||||||
|
return "", nil, fmt.Errorf("add firewall rule: %w", ErrNoRuleReturned)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive the pair id from the backend rule, like the route path:
|
||||||
|
// the backend dedups identical content, so two policies authorizing
|
||||||
|
// the same flow resolve to the same id and a single backing rule.
|
||||||
|
return fwRule.ID(), []firewall.Rule{fwRule}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
// Apply new rules - firewall manager will return the existing rule if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule, dynamicResolver)
|
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||||
@@ -163,16 +304,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newRouteRules[id] = struct{}{}
|
newRouteRules[addedRule.ID()] = addedRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up old firewall rules
|
// Tear down old route rules; retain ones the backend refused so a
|
||||||
for id := range d.routeRules {
|
// transient failure doesn't leave orphaned rules in the firewall.
|
||||||
if _, exists := newRouteRules[id]; !exists {
|
for ruleID, rule := range d.routeRules {
|
||||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
if _, exists := newRouteRules[ruleID]; exists {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
continue
|
||||||
}
|
}
|
||||||
// implicitly deleted from the map
|
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
|
||||||
|
newRouteRules[ruleID] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,102 +323,202 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", ErrSourceRangesEmpty
|
return nil, ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
var sources []netip.Prefix
|
var sources []netip.Prefix
|
||||||
for _, sourceRange := range rule.SourceRanges {
|
for _, sourceRange := range rule.SourceRanges {
|
||||||
source, err := netip.ParsePrefix(sourceRange)
|
source, err := netip.ParsePrefix(sourceRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("parse source range: %w", err)
|
return nil, fmt.Errorf("parse source range: %w", err)
|
||||||
}
|
}
|
||||||
sources = append(sources, source)
|
sources = append(sources, firewall.UnmapPrefix(source))
|
||||||
}
|
}
|
||||||
|
|
||||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("determine destination: %w", err)
|
return nil, fmt.Errorf("determine destination: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid protocol: %w", err)
|
return nil, fmt.Errorf("invalid protocol: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
action, err := convertFirewallAction(rule.Action)
|
action, err := convertFirewallAction(rule.Action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid action: %w", err)
|
return nil, fmt.Errorf("invalid action: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return nil, fmt.Errorf("add route rule: %w", err)
|
||||||
|
}
|
||||||
|
if addedRule == nil {
|
||||||
|
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
|
||||||
}
|
}
|
||||||
|
|
||||||
return id.RuleID(addedRule.ID()), nil
|
return addedRule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
// splitDenyAccept partitions groups by action so denies can be
|
||||||
r *mgmProto.FirewallRule,
|
// applied before accepts. Order within each bucket is preserved.
|
||||||
ipsetName string,
|
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
|
||||||
) (id.RuleID, []firewall.Rule, error) {
|
for _, g := range groups {
|
||||||
ip, err := extractRuleIP(r)
|
if g.action == mgmProto.RuleAction_DROP {
|
||||||
if err != nil {
|
denies = append(denies, g)
|
||||||
return "", nil, err
|
} else {
|
||||||
|
accepts = append(accepts, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return denies, accepts
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupPeerRules merges single-source rules sharing a selector into
|
||||||
|
// multi-source groups. It splits source-decode failures by action:
|
||||||
|
// denyErr is non-nil when a deny rule could not be decoded, which is a
|
||||||
|
// fail-open risk the caller must treat as fatal for the pass; err
|
||||||
|
// carries the tolerable accept-rule failures the caller can log and
|
||||||
|
// continue past.
|
||||||
|
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
|
||||||
|
var denyMerr, acceptMerr *multierror.Error
|
||||||
|
byKey := make(map[peerRuleKey]*peerRuleGroup)
|
||||||
|
order := make([]peerRuleKey, 0)
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
srcs, decErr := extractRuleSources(r)
|
||||||
|
if decErr != nil {
|
||||||
|
if r.Action == mgmProto.RuleAction_DROP {
|
||||||
|
denyMerr = multierror.Append(denyMerr, decErr)
|
||||||
|
} else {
|
||||||
|
acceptMerr = multierror.Append(acceptMerr, decErr)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// A single FirewallRule normally carries one address family, but
|
||||||
|
// split by family defensively: each backend keys a rule to one
|
||||||
|
// family and would mismatch sources of the other, so a group's
|
||||||
|
// sources must never span families.
|
||||||
|
v4, v6 := splitPrefixesByFamily(srcs)
|
||||||
|
for _, sub := range []struct {
|
||||||
|
isV6 bool
|
||||||
|
sources []netip.Prefix
|
||||||
|
}{{false, v4}, {true, v6}} {
|
||||||
|
if len(sub.sources) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := ruleGroupKey(r, sub.isV6)
|
||||||
|
g, ok := byKey[key]
|
||||||
|
if !ok {
|
||||||
|
g = &peerRuleGroup{
|
||||||
|
direction: r.Direction,
|
||||||
|
action: r.Action,
|
||||||
|
protocol: r.Protocol,
|
||||||
|
port: r.PortInfo,
|
||||||
|
legacyPort: r.Port,
|
||||||
|
policyID: r.PolicyID,
|
||||||
|
}
|
||||||
|
byKey[key] = g
|
||||||
|
order = append(order, key)
|
||||||
|
}
|
||||||
|
g.sources = append(g.sources, sub.sources...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
out := make([]*peerRuleGroup, 0, len(order))
|
||||||
if err != nil {
|
for _, k := range order {
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
out = append(out, byKey[k])
|
||||||
|
}
|
||||||
|
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixIsV6(p netip.Prefix) bool {
|
||||||
|
return p.Addr().Is6() && !p.Addr().Is4In6()
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
|
||||||
|
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||||
|
for _, p := range prefixes {
|
||||||
|
if prefixIsV6(p) {
|
||||||
|
v6 = append(v6, p)
|
||||||
|
} else {
|
||||||
|
v4 = append(v4, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v4, v6
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
|
||||||
|
// rule's source family: mgmt emits one rule per family and mixing them
|
||||||
|
// would break ICMP-variant selection in uspfilter.
|
||||||
|
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
|
||||||
|
k := peerRuleKey{
|
||||||
|
v6: v6,
|
||||||
|
policyID: string(r.PolicyID),
|
||||||
|
direction: r.Direction,
|
||||||
|
action: r.Action,
|
||||||
|
protocol: r.Protocol,
|
||||||
|
legacyPort: r.Port,
|
||||||
|
}
|
||||||
|
if pi := r.PortInfo; pi != nil {
|
||||||
|
k.port = uint16(pi.GetPort())
|
||||||
|
if rng := pi.GetRange(); rng != nil {
|
||||||
|
k.rangeStart = uint16(rng.GetStart())
|
||||||
|
k.rangeEnd = uint16(rng.GetEnd())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRuleSources returns all source prefixes the rule applies to.
|
||||||
|
// New management populates sourcePrefixes; older management sets PeerIP.
|
||||||
|
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
|
||||||
|
if len(r.SourcePrefixes) > 0 {
|
||||||
|
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
|
||||||
|
for _, raw := range r.SourcePrefixes {
|
||||||
|
addr, err := netiputil.DecodeAddr(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode source prefix: %w", err)
|
||||||
|
}
|
||||||
|
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
action, err := convertFirewallAction(r.Action)
|
peerIP := r.PeerIP //nolint:staticcheck // PeerIP is the legacy source field for old management servers
|
||||||
|
addr, err := netip.ParseAddr(peerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
return nil, fmt.Errorf("parse peer IP %q: %w", peerIP, err)
|
||||||
}
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
// An unspecified PeerIP means "any peer" (legacy management
|
||||||
|
// allow-all fallback); only a /0 prefix matches any source in the
|
||||||
|
// backends, a full-length prefix would match nothing.
|
||||||
|
if addr.IsUnspecified() {
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(addr, 0)}, nil
|
||||||
|
}
|
||||||
|
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
|
||||||
|
}
|
||||||
|
|
||||||
var port *firewall.Port
|
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
|
||||||
if !portInfoEmpty(r.PortInfo) {
|
if !portInfoEmpty(g.port) {
|
||||||
port = convertPortInfo(r.PortInfo)
|
return convertPortInfo(g.port), nil
|
||||||
} else if r.Port != "" {
|
}
|
||||||
// old version of management, single port
|
if g.legacyPort != "" {
|
||||||
value, err := strconv.Atoi(r.Port)
|
value, err := strconv.ParseUint(g.legacyPort, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
return nil, fmt.Errorf("invalid port: %w", err)
|
||||||
}
|
}
|
||||||
port = &firewall.Port{
|
return &firewall.Port{
|
||||||
Values: []uint16{uint16(value)},
|
Values: []uint16{uint16(value)},
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
// nolint:nilnil // a nil port legitimately means "no port restriction"
|
||||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
return nil, nil
|
||||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
|
||||||
return ruleID, rulesPair, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []firewall.Rule
|
|
||||||
switch r.Direction {
|
|
||||||
case mgmProto.RuleDirection_IN:
|
|
||||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
|
||||||
case mgmProto.RuleDirection_OUT:
|
|
||||||
if d.firewall.IsStateful() {
|
|
||||||
return "", nil, nil
|
|
||||||
}
|
|
||||||
// return traffic for outbound connections if firewall is stateless
|
|
||||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ruleID, rules, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||||
@@ -294,85 +537,9 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
// ConvertToFirewallProtocol maps a management rule protocol to the
|
||||||
id []byte,
|
// firewall protocol type.
|
||||||
ip netip.Addr,
|
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
|
||||||
id []byte,
|
|
||||||
ip netip.Addr,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) ([]firewall.Rule, error) {
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
|
||||||
func (d *DefaultManager) getPeerRuleID(
|
|
||||||
ip netip.Addr,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
direction int,
|
|
||||||
port *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
) id.RuleID {
|
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
|
||||||
if port != nil {
|
|
||||||
idStr += port.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
|
||||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
|
||||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
|
||||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
|
||||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
|
||||||
if len(r.SourcePrefixes) > 0 {
|
|
||||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
|
||||||
}
|
|
||||||
return addr.Unmap(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
|
||||||
addr, err := netip.ParseAddr(r.PeerIP)
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
|
||||||
}
|
|
||||||
return addr.Unmap(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.RuleProtocol_TCP:
|
case mgmProto.RuleProtocol_TCP:
|
||||||
return firewall.ProtocolTCP, nil
|
return firewall.ProtocolTCP, nil
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
existedPairs := map[string]struct{}{}
|
existedPairs := map[fwmanager.RuleID]struct{}{}
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
existedPairs[id.ID()] = struct{}{}
|
existedPairs[id] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
if _, ok := existedPairs[id.ID()]; ok {
|
if _, ok := existedPairs[id]; ok {
|
||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains.
|
// Resolver caches critical NetBird infrastructure domains.
|
||||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||||
|
// guarded by mutex.
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question]*cachedRecord
|
records map[dns.Question]*cachedRecord
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// failedResolves records the last failed initial resolve per domain so a
|
||||||
|
// domain that never resolves isn't retried on every server-domains update
|
||||||
|
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||||
|
// to the current server-domains set.
|
||||||
|
failedResolves map[domain.Domain]time.Time
|
||||||
|
|
||||||
chain ChainResolver
|
chain ChainResolver
|
||||||
chainMaxPriority int
|
chainMaxPriority int
|
||||||
refreshGroup singleflight.Group
|
refreshGroup singleflight.Group
|
||||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
|||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question]*cachedRecord),
|
records: make(map[dns.Question]*cachedRecord),
|
||||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||||
cacheTTL: resolveCacheTTL(),
|
failedResolves: make(map[domain.Domain]time.Time),
|
||||||
|
cacheTTL: resolveCacheTTL(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||||
// entry for that qtype.
|
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||||
|
// the resolved family is still cached but AddDomain returns an error so the
|
||||||
|
// caller retries the incomplete resolve rather than treating it as complete.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
|
if errA != nil || errAAAA != nil {
|
||||||
|
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
delete(m.records, qAAAA)
|
delete(m.records, qAAAA)
|
||||||
delete(m.refreshing, qA)
|
delete(m.refreshing, qA)
|
||||||
delete(m.refreshing, qAAAA)
|
delete(m.refreshing, qAAAA)
|
||||||
|
delete(m.failedResolves, d)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
|||||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||||
currentDomains := m.GetCachedDomains()
|
currentDomains := m.GetCachedDomains()
|
||||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||||
|
m.pruneFailedResolves(allDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addNewDomains(ctx, newDomains)
|
m.addNewDomains(ctx, newDomains)
|
||||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
|||||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
// addNewDomains resolves and caches all domains from the update
|
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||||
|
// running the lookups concurrently. Domains already cached are skipped and left
|
||||||
|
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||||
|
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||||
|
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||||
|
// sync lock on every update.
|
||||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||||
for _, newDomain := range newDomains {
|
for _, newDomain := range newDomains {
|
||||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
if _, dup := seen[newDomain]; dup {
|
||||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
continue
|
||||||
} else {
|
}
|
||||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
seen[newDomain] = struct{}{}
|
||||||
|
|
||||||
|
if !m.needsResolve(newDomain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(d domain.Domain) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
m.markResolveFailed(d)
|
||||||
|
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.clearResolveFailed(d)
|
||||||
|
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||||
|
}(newDomain)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||||
|
// incomplete resolve gates retries on the backoff even when one family is
|
||||||
|
// already cached, so a transiently-failed family is retried instead of being
|
||||||
|
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||||
|
// to the stale-while-revalidate refresh path.
|
||||||
|
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||||
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if failedAt, ok := m.failedResolves[d]; ok {
|
||||||
|
return time.Since(failedAt) >= refreshBackoff
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||||
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||||
|
if _, ok := m.records[q]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.failedResolves[d] = time.Now()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
delete(m.failedResolves, d)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||||
|
// the server-domains set, keeping the map bounded to the current set (a
|
||||||
|
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||||
|
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
for d := range m.failedResolves {
|
||||||
|
if !slices.Contains(domains, d) {
|
||||||
|
delete(m.failedResolves, d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
calls map[string]int
|
calls map[string]int
|
||||||
answers map[string][]dns.RR
|
answers map[string][]dns.RR
|
||||||
|
qErr map[string]error
|
||||||
err error
|
err error
|
||||||
hasRoot bool
|
hasRoot bool
|
||||||
onLookup func()
|
onLookup func()
|
||||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
|||||||
return &fakeChain{
|
return &fakeChain{
|
||||||
calls: map[string]int{},
|
calls: map[string]int{},
|
||||||
answers: map[string][]dns.RR{},
|
answers: map[string][]dns.RR{},
|
||||||
|
qErr: map[string]error{},
|
||||||
hasRoot: true,
|
hasRoot: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
|||||||
f.calls[key]++
|
f.calls[key]++
|
||||||
answers := f.answers[key]
|
answers := f.answers[key]
|
||||||
err := f.err
|
err := f.err
|
||||||
|
if err == nil {
|
||||||
|
err = f.qErr[key]
|
||||||
|
}
|
||||||
onLookup := f.onLookup
|
onLookup := f.onLookup
|
||||||
f.mu.Unlock()
|
f.mu.Unlock()
|
||||||
|
|
||||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|||||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||||
|
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||||
|
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"first update must resolve the domain")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"cached domain must not be re-resolved on a subsequent update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// New domains in a single update must resolve concurrently rather than serially.
|
||||||
|
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
|
||||||
|
var inflight, maxInflight atomic.Int32
|
||||||
|
chain.onLookup = func() {
|
||||||
|
n := inflight.Add(1)
|
||||||
|
for {
|
||||||
|
old := maxInflight.Load()
|
||||||
|
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
inflight.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||||
|
for _, d := range relays {
|
||||||
|
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||||
|
}
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||||
|
require.NoError(t, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||||
|
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||||
|
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A domain that fails to resolve must not be retried on every update; the
|
||||||
|
// failure backoff suppresses re-resolution until it expires.
|
||||||
|
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("resolve boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"first update must attempt the resolve")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"failed resolve must back off and not retry on the next update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||||
|
// the same host) must be resolved once per update, not once per occurrence.
|
||||||
|
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{
|
||||||
|
Stuns: []domain.Domain{"dup.example.com"},
|
||||||
|
Turns: []domain.Domain{"dup.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||||
|
"a domain appearing under multiple server-domain types must resolve once")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||||
|
// so the map stays bounded to the current set.
|
||||||
|
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("resolve boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.True(t, marked, "failed resolve must be recorded")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||||
|
}
|
||||||
|
|
||||||
|
// When one family hard-errors while the other resolves, the domain is cached
|
||||||
|
// for the working family but recorded as incomplete so the failed family is
|
||||||
|
// retried under backoff instead of being treated as fully resolved forever.
|
||||||
|
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||||
|
d := domain.Domain("relay.example.com")
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||||
|
_, marked := r.failedResolves[d]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.True(t, aCached, "the working family must still be cached")
|
||||||
|
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||||
|
|
||||||
|
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||||
|
r.mutex.Unlock()
|
||||||
|
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||||
|
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||||
|
// re-resolved on every sync.
|
||||||
|
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||||
|
d := domain.Domain("v4only.example.com")
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, marked := r.failedResolves[d]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||||
|
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||||
|
}
|
||||||
485
client/internal/dns/server_privileged_test.go
Normal file
485
client/internal/dns/server_privileged_test.go
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initUpstreamMap []handlerWrapper
|
||||||
|
initLocalZones []nbdns.CustomZone
|
||||||
|
initSerial uint64
|
||||||
|
inputSerial uint64
|
||||||
|
inputUpdate nbdns.Config
|
||||||
|
shouldFail bool
|
||||||
|
expectedUpstreamMap []handlerWrapper
|
||||||
|
expectedLocalQs []dns.Question
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Initial Config Should Succeed",
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New Config Should Succeed",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 2,
|
||||||
|
inputSerial: 1,
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{{
|
||||||
|
domain: ".",
|
||||||
|
priority: PriorityDefault,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disabled Service Should clean map",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||||
|
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wgIface.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = dnsServer.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||||
|
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||||
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
|
if err != nil {
|
||||||
|
if testCase.shouldFail {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||||
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range testCase.expectedUpstreamMap {
|
||||||
|
found := false
|
||||||
|
for _, got := range dnsServer.dnsMuxHandlers {
|
||||||
|
if got.domain == expected.domain && got.priority == expected.priority {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, q := range testCase.expectedLocalQs {
|
||||||
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||||
|
Question: []dns.Question{q},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testCase.expectedLocalQs) > 0 {
|
||||||
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||||
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create stdnet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: "utun2301",
|
||||||
|
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create and init wireguard interface: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = wgIface.Close(); err != nil {
|
||||||
|
t.Logf("close wireguard interface: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("run DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||||
|
t.Logf("restore DNS settings on the host: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &local.Resolver{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||||
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server with regular configuration
|
||||||
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := update
|
||||||
|
update2.ServiceEnable = false
|
||||||
|
// Disable the server, stop the listener
|
||||||
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update3 := update2
|
||||||
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||||
|
// But service still get updates and we checking that we handle
|
||||||
|
// internal state in the right way
|
||||||
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -23,7 +22,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -104,466 +102,6 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
initUpstreamMap []handlerWrapper
|
|
||||||
initLocalZones []nbdns.CustomZone
|
|
||||||
initSerial uint64
|
|
||||||
inputSerial uint64
|
|
||||||
inputUpdate nbdns.Config
|
|
||||||
shouldFail bool
|
|
||||||
expectedUpstreamMap []handlerWrapper
|
|
||||||
expectedLocalQs []dns.Question
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Initial Config Should Succeed",
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.io",
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
priority: PriorityDefault,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "New Config Should Succeed",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.io",
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 2,
|
|
||||||
inputSerial: 1,
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{{
|
|
||||||
domain: ".",
|
|
||||||
priority: PriorityDefault,
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
|
||||||
expectedUpstreamMap: nil,
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Disabled Service Should clean map",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
|
||||||
expectedUpstreamMap: nil,
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
|
||||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = wgIface.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = dnsServer.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
|
||||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
|
||||||
if err != nil {
|
|
||||||
if testCase.shouldFail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
|
||||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expected := range testCase.expectedUpstreamMap {
|
|
||||||
found := false
|
|
||||||
for _, got := range dnsServer.dnsMuxHandlers {
|
|
||||||
if got.domain == expected.domain && got.priority == expected.priority {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
|
||||||
responseWriter := &test.MockResponseWriter{
|
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
responseMSG = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, q := range testCase.expectedLocalQs {
|
|
||||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
|
||||||
Question: []dns.Question{q},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(testCase.expectedLocalQs) > 0 {
|
|
||||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
|
||||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create stdnet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: "utun2301",
|
|
||||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("build interface wireguard: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create and init wireguard interface: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = wgIface.Close(); err != nil {
|
|
||||||
t.Logf("close wireguard interface: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
||||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
||||||
t.Errorf("set packet filter: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("run DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
||||||
t.Logf("restore DNS settings on the host: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &local.Resolver{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
|
||||||
dnsServer.updateSerial = 0
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server with regular configuration
|
|
||||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update2 := update
|
|
||||||
update2.ServiceEnable = false
|
|
||||||
// Disable the server, stop the listener
|
|
||||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update3 := update2
|
|
||||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
||||||
// But service still get updates and we checking that we handle
|
|
||||||
// internal state in the right way
|
|
||||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -885,7 +423,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
pf, err := uspfilter.Create(uspfilter.Config{IFace: wgIface, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package dnsfwd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||||
|
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -174,8 +174,12 @@ func (m *Manager) allowDNSFirewall() error {
|
|||||||
return fmt.Errorf("flush: %w", err)
|
return fmt.Errorf("flush: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.fwRules = dnsRules
|
if dnsRule != nil {
|
||||||
m.tcpRules = tcpRules
|
m.fwRules = []firewall.Rule{dnsRule}
|
||||||
|
}
|
||||||
|
if tcpRule != nil {
|
||||||
|
m.tcpRules = []firewall.Rule{tcpRule}
|
||||||
|
}
|
||||||
|
|
||||||
m.registerNetstackServices()
|
m.registerNetstackServices()
|
||||||
|
|
||||||
@@ -209,12 +213,12 @@ func (m *Manager) unregisterNetstackServices() {
|
|||||||
func (m *Manager) dropDNSFirewall() error {
|
func (m *Manager) dropDNSFirewall() error {
|
||||||
var mErr *multierror.Error
|
var mErr *multierror.Error
|
||||||
for _, rule := range m.fwRules {
|
for _, rule := range m.fwRules {
|
||||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, rule := range m.tcpRules {
|
for _, rule := range m.tcpRules {
|
||||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -673,6 +673,11 @@ func (e *Engine) initFirewall() error {
|
|||||||
return fmt.Errorf("set firewall: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: the firewall backends dedup filter rules by content, so a
|
||||||
|
// management route ACL with identical content would collapse onto the
|
||||||
|
// untracked drop rules installed here, and a later management delete
|
||||||
|
// could remove them. Needs backend refcounting or per-consumer key
|
||||||
|
// namespacing.
|
||||||
if e.config.BlockLANAccess {
|
if e.config.BlockLANAccess {
|
||||||
e.blockLanAccess()
|
e.blockLanAccess()
|
||||||
}
|
}
|
||||||
@@ -685,14 +690,14 @@ func (e *Engine) initFirewall() error {
|
|||||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||||
|
|
||||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
|
firewallManager.Network{},
|
||||||
firewallManager.ProtocolUDP,
|
firewallManager.ProtocolUDP,
|
||||||
nil,
|
nil,
|
||||||
&port,
|
&port,
|
||||||
firewallManager.ActionAccept,
|
firewallManager.ActionAccept,
|
||||||
"",
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -742,7 +747,7 @@ func (e *Engine) blockLanAccess() {
|
|||||||
if network.Addr().Is6() {
|
if network.Addr().Is6() {
|
||||||
source = v6
|
source = v6
|
||||||
}
|
}
|
||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddFilterRule(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{source},
|
[]netip.Prefix{source},
|
||||||
firewallManager.Network{Prefix: network},
|
firewallManager.Network{Prefix: network},
|
||||||
@@ -1066,7 +1071,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
}
|
}
|
||||||
e.checks = checks
|
e.checks = checks
|
||||||
|
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
@@ -1097,6 +1102,20 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||||
|
// can be excluded from the reported network addresses; the interface coming and
|
||||||
|
// going otherwise churns the peer meta on the management server.
|
||||||
|
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||||
|
var ips []netip.Addr
|
||||||
|
if e.config.WgAddr.IP.IsValid() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IP)
|
||||||
|
}
|
||||||
|
if e.config.WgAddr.HasIPv6() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IPv6)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
@@ -1240,7 +1259,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
@@ -2441,7 +2460,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
||||||
continue
|
continue
|
||||||
|
|||||||
565
client/internal/engine_privileged_test.go
Normal file
565
client/internal/engine_privileged_test.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEngine_SSH(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(
|
||||||
|
ctx, cancel,
|
||||||
|
&EngineConfig{
|
||||||
|
WgIfaceName: "utun101",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
ServerSSHAllowed: true,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
SSHKey: sshKey,
|
||||||
|
},
|
||||||
|
EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
},
|
||||||
|
MobileDependency{},
|
||||||
|
)
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.21/24"},
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||||
|
networkMap := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 6,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// SSH server is enabled, therefore SSH config should be applied
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 7,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now remove peer
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 8,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now disable SSH server
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 9,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Sync(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// feed updates to Engine via mocked Management client
|
||||||
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
|
defer close(updates)
|
||||||
|
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
|
for msg := range updates {
|
||||||
|
err := msgHandler(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
|
WgIfaceName: "utun103",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.10/24"},
|
||||||
|
}
|
||||||
|
peer2 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.11/24"},
|
||||||
|
}
|
||||||
|
peer3 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.12/24"},
|
||||||
|
}
|
||||||
|
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||||
|
updates <- &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.After(time.Second * 2)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("timeout while waiting for test to finish")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_MultiplePeers(t *testing.T) {
|
||||||
|
// log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigServer, signalAddr, err := startSignal(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sigServer.Stop()
|
||||||
|
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
|
mu := sync.Mutex{}
|
||||||
|
engines := []*Engine{}
|
||||||
|
numPeers := 10
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(numPeers)
|
||||||
|
// create and start peers
|
||||||
|
for i := 0; i < numPeers; i++ {
|
||||||
|
j := i
|
||||||
|
go func() {
|
||||||
|
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||||
|
if err != nil {
|
||||||
|
wg.Done()
|
||||||
|
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engine.dnsServer = &dns.MockServer{}
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engines = append(engines, engine)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until all have been created and started
|
||||||
|
wg.Wait()
|
||||||
|
if len(engines) != numPeers {
|
||||||
|
t.Fatal("not all peers were started")
|
||||||
|
}
|
||||||
|
// check whether all the peer have expected peers connected
|
||||||
|
|
||||||
|
expectedConnected := numPeers * (numPeers - 1)
|
||||||
|
|
||||||
|
// adjust according to timeouts
|
||||||
|
timeout := 50 * time.Second
|
||||||
|
timeoutChan := time.After(timeout)
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||||
|
break loop
|
||||||
|
case <-ticker.C:
|
||||||
|
totalConnected := 0
|
||||||
|
for _, engine := range engines {
|
||||||
|
totalConnected += getConnectedPeers(engine)
|
||||||
|
}
|
||||||
|
if totalConnected == expectedConnected {
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cleanup test
|
||||||
|
for n, peerEngine := range engines {
|
||||||
|
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||||
|
errStop := peerEngine.mgmClient.Close()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||||
|
}
|
||||||
|
errStop = peerEngine.Stop()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
kaep = keepalive.EnforcementPolicy{
|
||||||
|
MinTime: 15 * time.Second,
|
||||||
|
PermitWithoutStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kasp = keepalive.ServerParameters{
|
||||||
|
MaxConnectionIdle: 15 * time.Second,
|
||||||
|
MaxConnectionAgeGrace: 5 * time.Second,
|
||||||
|
Time: 5 * time.Second,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ifaceName string
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||||
|
} else {
|
||||||
|
ifaceName = fmt.Sprintf("wt%d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgPort := 33100 + i
|
||||||
|
conf := &EngineConfig{
|
||||||
|
WgIfaceName: ifaceName,
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: wgPort,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmtClient,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{}), nil
|
||||||
|
e.ctx = ctx
|
||||||
|
return e, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
config := &config.Config{
|
||||||
|
Stuns: []*config.Host{},
|
||||||
|
TURNConfig: &config.TURNConfig{},
|
||||||
|
Relay: &config.Relay{
|
||||||
|
Addresses: []string{"127.0.0.1:1234"},
|
||||||
|
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||||
|
Secret: "222222222222222222",
|
||||||
|
},
|
||||||
|
Signal: &config.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: "localhost:10000",
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.ExtraSettings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
|
func getConnectedPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
i := 0
|
||||||
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
|
if conn.IsConnected() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
return len(e.peerStore.PeersPubKey())
|
||||||
|
}
|
||||||
@@ -6,37 +6,18 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -50,18 +31,7 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
@@ -69,25 +39,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
kaep = keepalive.EnforcementPolicy{
|
|
||||||
MinTime: 15 * time.Second,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
kasp = keepalive.ServerParameters{
|
|
||||||
MaxConnectionIdle: 15 * time.Second,
|
|
||||||
MaxConnectionAgeGrace: 5 * time.Second,
|
|
||||||
Time: 5 * time.Second,
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockWGIface struct {
|
type MockWGIface struct {
|
||||||
CreateFunc func() error
|
CreateFunc func() error
|
||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||||
@@ -234,129 +188,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(
|
|
||||||
ctx, cancel,
|
|
||||||
&EngineConfig{
|
|
||||||
WgIfaceName: "utun101",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
ServerSSHAllowed: true,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
SSHKey: sshKey,
|
|
||||||
},
|
|
||||||
EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
},
|
|
||||||
MobileDependency{},
|
|
||||||
)
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.21/24"},
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
|
||||||
networkMap := &mgmtProto.NetworkMap{
|
|
||||||
Serial: 6,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// SSH server is enabled, therefore SSH config should be applied
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 7,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
JwtConfig: &mgmtProto.JWTConfig{
|
|
||||||
Issuer: "test-issuer",
|
|
||||||
Audience: "test-audience",
|
|
||||||
KeysLocation: "test-keys",
|
|
||||||
MaxTokenAge: 3600,
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now remove peer
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 8,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now disable SSH server
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 9,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||||
// Test that SSH server start/stop logic works based on config
|
// Test that SSH server start/stop logic works based on config
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
@@ -631,97 +462,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_Sync(t *testing.T) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// feed updates to Engine via mocked Management client
|
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
|
||||||
defer close(updates)
|
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
|
||||||
for msg := range updates {
|
|
||||||
err := msgHandler(msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
|
||||||
WgIfaceName: "utun103",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}, EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{})
|
|
||||||
engine.ctx = ctx
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
peer1 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.10/24"},
|
|
||||||
}
|
|
||||||
peer2 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.11/24"},
|
|
||||||
}
|
|
||||||
peer3 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.12/24"},
|
|
||||||
}
|
|
||||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
|
||||||
updates <- &mgmtProto.SyncResponse{
|
|
||||||
NetworkMap: &mgmtProto.NetworkMap{
|
|
||||||
Serial: 10,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := time.After(time.Second * 2)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
t.Fatalf("timeout while waiting for test to finish")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1105,104 +845,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_MultiplePeers(t *testing.T) {
|
|
||||||
// log.SetLevel(log.DebugLevel)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigServer, signalAddr, err := startSignal(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer sigServer.Stop()
|
|
||||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer mgmtServer.GracefulStop()
|
|
||||||
|
|
||||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
|
||||||
|
|
||||||
mu := sync.Mutex{}
|
|
||||||
engines := []*Engine{}
|
|
||||||
numPeers := 10
|
|
||||||
wg := sync.WaitGroup{}
|
|
||||||
wg.Add(numPeers)
|
|
||||||
// create and start peers
|
|
||||||
for i := 0; i < numPeers; i++ {
|
|
||||||
j := i
|
|
||||||
go func() {
|
|
||||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
|
||||||
if err != nil {
|
|
||||||
wg.Done()
|
|
||||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engine.dnsServer = &dns.MockServer{}
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
|
||||||
wg.Done()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engines = append(engines, engine)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until all have been created and started
|
|
||||||
wg.Wait()
|
|
||||||
if len(engines) != numPeers {
|
|
||||||
t.Fatal("not all peers was started")
|
|
||||||
}
|
|
||||||
// check whether all the peer have expected peers connected
|
|
||||||
|
|
||||||
expectedConnected := numPeers * (numPeers - 1)
|
|
||||||
|
|
||||||
// adjust according to timeouts
|
|
||||||
timeout := 50 * time.Second
|
|
||||||
timeoutChan := time.After(timeout)
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
|
||||||
break loop
|
|
||||||
case <-ticker.C:
|
|
||||||
totalConnected := 0
|
|
||||||
for _, engine := range engines {
|
|
||||||
totalConnected += getConnectedPeers(engine)
|
|
||||||
}
|
|
||||||
if totalConnected == expectedConnected {
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// cleanup test
|
|
||||||
for n, peerEngine := range engines {
|
|
||||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
|
||||||
errStop := peerEngine.mgmClient.Close()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
|
||||||
}
|
|
||||||
errStop = peerEngine.Stop()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||||
ifaceList, err := net.Interfaces()
|
ifaceList, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1526,187 +1168,6 @@ func TestCompareNetIPLists(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ifaceName string
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
|
||||||
} else {
|
|
||||||
ifaceName = fmt.Sprintf("wt%d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgPort := 33100 + i
|
|
||||||
conf := &EngineConfig{
|
|
||||||
WgIfaceName: ifaceName,
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: wgPort,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
|
||||||
SignalClient: signalClient,
|
|
||||||
MgmClient: mgmtClient,
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{}), nil
|
|
||||||
e.ctx = ctx
|
|
||||||
return e, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
|
||||||
require.NoError(t, err)
|
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
config := &config.Config{
|
|
||||||
Stuns: []*config.Host{},
|
|
||||||
TURNConfig: &config.TURNConfig{},
|
|
||||||
Relay: &config.Relay{
|
|
||||||
Addresses: []string{"127.0.0.1:1234"},
|
|
||||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
|
||||||
Secret: "222222222222222222",
|
|
||||||
},
|
|
||||||
Signal: &config.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: "localhost:10000",
|
|
||||||
},
|
|
||||||
Datadir: dataDir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.Settings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.ExtraSettings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func getConnectedPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
i := 0
|
|
||||||
for _, id := range e.peerStore.PeersPubKey() {
|
|
||||||
conn, _ := e.peerStore.PeerConn(id)
|
|
||||||
if conn.IsConnected() {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
return len(e.peerStore.PeersPubKey())
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
b, err := netiputil.EncodePrefix(p)
|
b, err := netiputil.EncodePrefix(p)
|
||||||
|
|||||||
@@ -24,14 +24,14 @@ type RulePair struct {
|
|||||||
type Manager struct {
|
type Manager struct {
|
||||||
dnatFirewall DNATFirewall
|
dnatFirewall DNATFirewall
|
||||||
|
|
||||||
rules map[string]RulePair // keys is the ID of the ForwardRule
|
rules map[firewall.RuleID]RulePair
|
||||||
rulesMu sync.Mutex
|
rulesMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
dnatFirewall: dnatFirewall,
|
dnatFirewall: dnatFirewall,
|
||||||
rules: make(map[string]RulePair),
|
rules: make(map[firewall.RuleID]RulePair),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
|||||||
|
|
||||||
var mErr *multierror.Error
|
var mErr *multierror.Error
|
||||||
|
|
||||||
toDelete := make(map[string]RulePair, len(h.rules))
|
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
|
||||||
for id, r := range h.rules {
|
for id, r := range h.rules {
|
||||||
toDelete[id] = r
|
toDelete[id] = r
|
||||||
}
|
}
|
||||||
@@ -59,6 +59,10 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
|||||||
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if rule == nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
log.Infof("forward rule has been added '%s'", fwdRule)
|
log.Infof("forward rule has been added '%s'", fwdRule)
|
||||||
h.rules[id] = RulePair{
|
h.rules[id] = RulePair{
|
||||||
ForwardRule: fwdRule,
|
ForwardRule: fwdRule,
|
||||||
@@ -90,7 +94,7 @@ func (h *Manager) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.rules = make(map[string]RulePair)
|
h.rules = make(map[firewall.RuleID]RulePair)
|
||||||
return nberrors.FormatErrorOrNil(mErr)
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MocFwRule struct {
|
type MocFwRule struct {
|
||||||
id string
|
id firewall.RuleID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MocFwRule) ID() string {
|
func (m *MocFwRule) ID() firewall.RuleID {
|
||||||
return string(m.id)
|
return m.id
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockDNATFirewall struct {
|
type MockDNATFirewall struct {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user