mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-29 03:09:56 +00:00
Compare commits
21 Commits
fix/mgm-he
...
nmap/compo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1347ad1ff | ||
|
|
db76f33b71 | ||
|
|
eab0826b4e | ||
|
|
7048b87931 | ||
|
|
596952265d | ||
|
|
21cfec93d4 | ||
|
|
98818e3095 | ||
|
|
5d5c2d9f95 | ||
|
|
13e41e432c | ||
|
|
efa6a3f502 | ||
|
|
5fbcdeceac | ||
|
|
3a1bbeba90 | ||
|
|
728057ef15 | ||
|
|
582cd70086 | ||
|
|
9bbbafaf69 | ||
|
|
672b057aa0 | ||
|
|
b9a0186200 | ||
|
|
9083bdb977 | ||
|
|
b194af48b8 | ||
|
|
4543780ef0 | ||
|
|
2de0283971 |
@@ -64,7 +64,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -21,13 +21,13 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
|
||||
20
.github/workflows/golang-test-freebsd.yml
vendored
20
.github/workflows/golang-test-freebsd.yml
vendored
@@ -48,14 +48,14 @@ jobs:
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -tags privileged -timeout 1m -failfast ./base62/...
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./dns/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./encryption/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./formatter/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./route/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./util/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./version/...
|
||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
50
.github/workflows/golang-test-linux.yml
vendored
50
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -124,7 +124,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,7 +158,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -180,7 +180,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -251,7 +251,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -266,7 +266,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -311,7 +311,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -325,7 +325,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -368,7 +368,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -383,7 +383,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -429,7 +429,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -440,7 +440,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -534,7 +534,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -545,7 +545,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -579,11 +579,10 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -629,7 +628,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -640,7 +639,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -674,13 +673,12 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
@@ -699,7 +697,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -710,7 +708,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
run: |
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
|
||||
- name: test
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
display_name: Linux
|
||||
name: ${{ matrix.display_name }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -62,4 +62,4 @@ jobs:
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
args: --timeout=12m
|
||||
args: --timeout=20m
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,13 +28,13 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -166,12 +166,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -374,12 +374,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -469,12 +469,12 @@ jobs:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
|
||||
@@ -73,12 +73,12 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
@@ -25,15 +25,3 @@ setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
|
||||
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||
test-unit:
|
||||
@go test -tags devcert -timeout 10m ./...
|
||||
|
||||
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||
# Narrow the run with env vars, e.g.:
|
||||
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||
test-privileged:
|
||||
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||
|
||||
@@ -37,11 +37,6 @@
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
> ### 🤖 NetBird Agent Network (Beta)
|
||||
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# NetBird Agent Network
|
||||
|
||||
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||
policy, with no API keys to leak.
|
||||
|
||||
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||
> infrastructure.
|
||||
|
||||
## How it works
|
||||
|
||||
Agent Network is built on two existing NetBird capabilities:
|
||||
|
||||
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||
key server-side, forwards to the API or gateway, and records usage.
|
||||
|
||||
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||
components:
|
||||
|
||||
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||
and runs the per-request middleware pipeline.
|
||||
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||
and usage/access logs.
|
||||
|
||||
## Documentation
|
||||
|
||||
Full documentation, architecture, and quickstart:
|
||||
**https://docs.netbird.io/agent-network**
|
||||
@@ -1,196 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -27,6 +31,186 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux || !privileged
|
||||
//go:build !linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
@@ -26,6 +26,64 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
@@ -198,64 +256,6 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
|
||||
@@ -1,485 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -102,6 +104,466 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -63,7 +63,9 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
types "github.com/netbirdio/netbird/shared/management/types"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -210,6 +212,13 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||
// NetworkMap that Calculate() produced from it so future incremental
|
||||
// updates have a base to apply changes against. nil for legacy-format
|
||||
// peers. Guarded by syncMsgMux.
|
||||
latestComponents *types.NetworkMapComponents
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -910,20 +919,54 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode the network map from either the components envelope or the
|
||||
// legacy proto.NetworkMap before the posture-check gating below, so the
|
||||
// "is there a network map" decision covers both wire shapes.
|
||||
var (
|
||||
nm *mgmProto.NetworkMap
|
||||
components *types.NetworkMapComponents
|
||||
)
|
||||
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||
// Components-format peer: decode the envelope back to typed
|
||||
// components, run Calculate() locally, and convert to the wire
|
||||
// NetworkMap shape the rest of the engine consumes. Components are
|
||||
// retained so future incremental updates can apply deltas instead
|
||||
// of doing a full reconstruction.
|
||||
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
dnsName := ""
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||
// shared domain by stripping the peer's own label prefix. Falls
|
||||
// back to empty if the FQDN doesn't have the expected shape.
|
||||
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||
}
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode network map envelope: %w", err)
|
||||
}
|
||||
nm = result.NetworkMap
|
||||
components = result.Components
|
||||
} else {
|
||||
nm = update.GetNetworkMap()
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -932,6 +975,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only retain the components view when the server sent the envelope
|
||||
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||
// here would clobber a previously-cached snapshot, breaking the
|
||||
// incremental-delta base on a future envelope sync.
|
||||
if components != nil {
|
||||
e.latestComponents = components
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
@@ -944,6 +995,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||
// receiving peer's FQDN — the same value the management server fills as
|
||||
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||
for i := 0; i < len(fqdn); i++ {
|
||||
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||
return fqdn[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
@@ -1066,7 +1130,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
@@ -1097,20 +1161,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||
// can be excluded from the reported network addresses; the interface coming and
|
||||
// going otherwise churns the peer meta on the management server.
|
||||
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||
var ips []netip.Addr
|
||||
if e.config.WgAddr.IP.IsValid() {
|
||||
ips = append(ips, e.config.WgAddr.IP)
|
||||
}
|
||||
if e.config.WgAddr.HasIPv6() {
|
||||
ips = append(ips, e.config.WgAddr.IPv6)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -1254,7 +1304,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
|
||||
@@ -1,565 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers were started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
@@ -6,18 +6,37 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -31,7 +50,18 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
@@ -39,9 +69,25 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
@@ -188,6 +234,129 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
@@ -462,6 +631,97 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -845,6 +1105,104 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
ifaceList, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -1168,6 +1526,187 @@ func TestCompareNetIPLists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
|
||||
@@ -119,6 +119,10 @@ func (d *BindListener) ReadPackets() {
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.lazyConn.Close()
|
||||
d.bind.RemoveEndpoint(d.fakeIP)
|
||||
d.done.Done()
|
||||
|
||||
@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
|
||||
}
|
||||
|
||||
offer := h.buildOfferAnswer()
|
||||
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
|
||||
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||
}
|
||||
|
||||
func (h *Handshaker) sendAnswer() error {
|
||||
answer := h.buildOfferAnswer()
|
||||
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
|
||||
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalAnswer(answer, h.config.Key)
|
||||
}
|
||||
|
||||
@@ -192,7 +192,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
muxRelays sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
@@ -245,8 +244,8 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
}
|
||||
|
||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayMgr = manager
|
||||
}
|
||||
|
||||
@@ -907,8 +906,8 @@ func (d *Status) MarkSignalConnected() {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
@@ -1019,26 +1018,24 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.muxRelays.RLock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.relayMgr == nil {
|
||||
defer d.muxRelays.RUnlock()
|
||||
return slices.Clone(d.relayStates)
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
relayMgr := d.relayMgr
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
d.muxRelays.RUnlock()
|
||||
|
||||
states := relayMgr.RelayStates()
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range relayMgr.ServerURLs() {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
|
||||
@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
|
||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
||||
if *input.ServerSSHAllowed {
|
||||
log.Infof("enabling SSH server")
|
||||
} else {
|
||||
|
||||
@@ -242,35 +242,6 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
|
||||
// Configs written before ServerSSHAllowed was introduced lack the field and
|
||||
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
|
||||
// apply the value instead of panicking on a nil pointer dereference.
|
||||
tests := []struct {
|
||||
name string
|
||||
input *bool
|
||||
want bool
|
||||
}{
|
||||
{"enable", util.True(), true},
|
||||
{"disable", util.False(), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
|
||||
|
||||
config, err := UpdateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
ServerSSHAllowed: tt.input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
|
||||
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build privileged
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
@@ -3,24 +3,79 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "utun100"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "lo0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "lo0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -67,3 +122,122 @@ func TestBits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && !ios && privileged
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -26,6 +26,11 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -510,3 +515,125 @@ func setupTestEnv(t *testing.T) {
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@@ -15,6 +18,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
@@ -26,6 +33,62 @@ func init() {
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "dummyext0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "dummyint0"
|
||||
@@ -1,83 +0,0 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||
// BSD/darwin test files compile without the privileged build tag.
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -20,6 +20,63 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
nbnet.Init()
|
||||
for _, tc := range testCases {
|
||||
@@ -45,6 +102,16 @@ func TestRouting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build windows && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android && privileged
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -8,14 +8,11 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -67,7 +67,6 @@ var boolStringLiterals = map[string]bool{
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
|
||||
@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
@@ -2,22 +2,124 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
@@ -157,3 +259,119 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
@@ -77,6 +78,53 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
@@ -106,6 +154,59 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1,423 +0,0 @@
|
||||
//go:build privileged
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -1,12 +1,25 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
@@ -15,7 +28,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -89,6 +106,331 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
// Configure SSH authorization for the test user
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
@@ -150,6 +492,10 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
@@ -162,3 +508,63 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": user,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
//go:build unix && privileged
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
@@ -73,6 +73,61 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
|
||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||
// This test requires root privileges and will be skipped if not running as root
|
||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -122,23 +121,6 @@ func (i *Info) SetFlags(
|
||||
}
|
||||
}
|
||||
|
||||
// removeAddresses drops network addresses whose IP matches any of the given
|
||||
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
|
||||
// address, which otherwise churns the meta as the interface comes and goes.
|
||||
func (i *Info) removeAddresses(ips ...netip.Addr) {
|
||||
if len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
filtered := i.NetworkAddresses[:0]
|
||||
for _, addr := range i.NetworkAddresses {
|
||||
if slices.Contains(ips, addr.NetIP.Addr()) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, addr)
|
||||
}
|
||||
i.NetworkAddresses = filtered
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
@@ -165,9 +147,7 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
// excludeIPs are dropped from the reported network addresses (e.g. our own
|
||||
// WireGuard overlay address, which otherwise churns the peer meta).
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||
processCheckPaths := make([]string, 0)
|
||||
for _, check := range checks {
|
||||
@@ -182,7 +162,6 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
|
||||
info := GetInfo(ctx)
|
||||
info.Files = files
|
||||
info.removeAddresses(excludeIPs...)
|
||||
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
|
||||
@@ -2,7 +2,6 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -44,42 +43,3 @@ func Test_NetAddresses(t *testing.T) {
|
||||
t.Errorf("no network addresses found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfo_RemoveAddresses(t *testing.T) {
|
||||
addr := func(cidr string) NetworkAddress {
|
||||
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
NetworkAddresses: []NetworkAddress{
|
||||
addr("192.168.1.7/24"),
|
||||
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
|
||||
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
|
||||
addr("fd00:1234::1/64"), // overlay v6
|
||||
},
|
||||
}
|
||||
|
||||
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
|
||||
info.removeAddresses(
|
||||
netip.MustParseAddr("100.76.70.97"),
|
||||
netip.MustParseAddr("fd00:1234::1"),
|
||||
)
|
||||
|
||||
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
|
||||
if len(info.NetworkAddresses) != len(want) {
|
||||
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
|
||||
}
|
||||
for i, w := range want {
|
||||
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
|
||||
t.Errorf("address[%d] = %s, want %s", i, got, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
|
||||
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
|
||||
info.removeAddresses()
|
||||
if len(info.NetworkAddresses) != 1 {
|
||||
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,7 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
// Skip link-local and multicast: they carry no routable peer info and the
|
||||
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
|
||||
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||
if ipNet.IP.IsLoopback() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
|
||||
t.Helper()
|
||||
ip, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %q: %v", cidr, err)
|
||||
}
|
||||
ipNet.IP = ip
|
||||
return ipNet
|
||||
}
|
||||
|
||||
func TestToNetworkAddress_Filtering(t *testing.T) {
|
||||
const mac = "c8:4b:d6:b6:04:ac"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cidr string
|
||||
want bool
|
||||
}{
|
||||
{"ipv4 global", "10.65.16.181/23", true},
|
||||
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
|
||||
{"ipv4 loopback", "127.0.0.1/8", false},
|
||||
{"ipv6 loopback", "::1/128", false},
|
||||
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
|
||||
{"ipv4 link-local", "169.254.1.2/16", false},
|
||||
{"ipv6 multicast", "ff02::1/128", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
|
||||
if got != tt.want {
|
||||
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
//go:build privileged && (linux || darwin)
|
||||
|
||||
// Package privileged provides a self-hosting harness that runs the repo's
|
||||
// privileged-tagged test suite inside a --privileged --cap-add=NET_ADMIN
|
||||
// container, so developers can exercise the root/system-mutating tests on a
|
||||
// non-root host with a single `go test` invocation.
|
||||
package privileged
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/moby/moby/api/types/container"
|
||||
"github.com/ory/dockertest/v4"
|
||||
)
|
||||
|
||||
// containerImage / containerTag match the image used by the CI privileged job
|
||||
// (.github/workflows/golang-test-linux.yml, test_client_on_docker).
|
||||
const (
|
||||
containerImage = "golang"
|
||||
containerTag = "1.25-alpine"
|
||||
)
|
||||
|
||||
const (
|
||||
containerWorkdir = "/app"
|
||||
containerGoCache = "/root/.cache/go-build"
|
||||
containerGoModCache = "/go/pkg/mod"
|
||||
)
|
||||
|
||||
// alpinePackages are the build/runtime deps the privileged tests need, mirroring
|
||||
// the CI container setup.
|
||||
const alpinePackages = "ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base"
|
||||
|
||||
// privilegedTestPackages is the package list the suite runs, excluding the
|
||||
// server-side trees and UI/upload helpers, matching the CI Docker job's filter.
|
||||
const privilegedTestPackages = `go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server`
|
||||
|
||||
// testWriter forwards container output to the test log line by line.
|
||||
type testWriter struct{ t *testing.T }
|
||||
|
||||
func (w testWriter) Write(p []byte) (int, error) {
|
||||
for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
|
||||
w.t.Log(line)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// TestRunPrivilegedSuiteInDocker spins up a privileged container, mounts the repo,
|
||||
// and runs `go test -tags 'devcert privileged'` inside it. When already running
|
||||
// inside that container (DOCKER_CI=true) it returns immediately so the real
|
||||
// privileged tests in the suite execute in place instead of recursing.
|
||||
func TestRunPrivilegedSuiteInDocker(t *testing.T) {
|
||||
if os.Getenv("DOCKER_CI") == "true" {
|
||||
t.Skip("inside privileged container, skipping container spawn; privileged tests run in place")
|
||||
}
|
||||
|
||||
repoRoot, err := findRepoRoot()
|
||||
if err != nil {
|
||||
t.Fatalf("locate repo root: %v", err)
|
||||
}
|
||||
goCache, goModCache := hostGoCaches(t)
|
||||
|
||||
// dockertest reads DOCKER_HOST; point it at the active context's socket when
|
||||
// the default one is absent (macOS Docker Desktop, Colima, OrbStack).
|
||||
if host := dockerHost(); host != "" {
|
||||
t.Setenv("DOCKER_HOST", host)
|
||||
}
|
||||
|
||||
// NewPoolT registers container cleanup via t.Cleanup automatically.
|
||||
pool := dockertest.NewPoolT(t, "", dockertest.WithMaxWait(30*time.Minute))
|
||||
|
||||
// Keep the container alive so the suite runs via Exec, which yields a clean
|
||||
// exit code (the v4 Resource API exposes no container wait/exit-code).
|
||||
resource := pool.RunT(t, containerImage,
|
||||
dockertest.WithTag(containerTag),
|
||||
dockertest.WithWorkingDir(containerWorkdir),
|
||||
dockertest.WithMounts([]string{
|
||||
repoRoot + ":" + containerWorkdir,
|
||||
goCache + ":" + containerGoCache,
|
||||
goModCache + ":" + containerGoModCache,
|
||||
}),
|
||||
dockertest.WithEnv([]string{
|
||||
"CGO_ENABLED=1",
|
||||
"CI=true",
|
||||
"DOCKER_CI=true",
|
||||
"CONTAINER=true",
|
||||
"GOCACHE=" + containerGoCache,
|
||||
"GOMODCACHE=" + containerGoModCache,
|
||||
}),
|
||||
dockertest.WithCmd([]string{"sleep", "infinity"}),
|
||||
dockertest.WithHostConfig(func(hc *container.HostConfig) {
|
||||
hc.Privileged = true
|
||||
hc.CapAdd = []string{"NET_ADMIN"}
|
||||
}),
|
||||
dockertest.WithoutReuse(),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
result, err := resource.Exec(ctx, []string{"sh", "-c", buildTestScript()})
|
||||
if err != nil {
|
||||
t.Fatalf("run privileged suite in container: %v", err)
|
||||
}
|
||||
|
||||
w := testWriter{t}
|
||||
_, _ = w.Write([]byte(result.StdOut))
|
||||
_, _ = w.Write([]byte(result.StdErr))
|
||||
|
||||
if result.ExitCode != 0 {
|
||||
t.Fatalf("privileged test suite failed in container (exit code %d)", result.ExitCode)
|
||||
}
|
||||
}
|
||||
|
||||
// findRepoRoot walks up from the test's working directory to the module root.
|
||||
func findRepoRoot() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for {
|
||||
if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
|
||||
return dir, nil
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
return "", fmt.Errorf("go.mod not found above %s", dir)
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// dockerHost returns a DOCKER_HOST override when the default socket is missing.
|
||||
// An empty result means the caller should leave DOCKER_HOST untouched (it is
|
||||
// already set, or the default unix socket exists). When neither is present
|
||||
// (common on macOS Docker Desktop, Colima and OrbStack, which use a per-user
|
||||
// socket), it resolves the active docker context's endpoint.
|
||||
func dockerHost() string {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return ""
|
||||
}
|
||||
if _, err := os.Stat("/var/run/docker.sock"); err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
out, err := exec.Command("docker", "context", "inspect", "-f", "{{.Endpoints.docker.Host}}").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// hostGoCaches resolves the host GOCACHE/GOMODCACHE so the container reuses the
|
||||
// existing build/module cache for speed.
|
||||
func hostGoCaches(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
return goEnv(t, "GOCACHE"), goEnv(t, "GOMODCACHE")
|
||||
}
|
||||
|
||||
func goEnv(t *testing.T, key string) string {
|
||||
t.Helper()
|
||||
var out bytes.Buffer
|
||||
cmd := exec.Command("go", "env", key)
|
||||
cmd.Stdout = &out
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("go env %s: %v", key, err)
|
||||
}
|
||||
return strings.TrimSpace(out.String())
|
||||
}
|
||||
|
||||
// buildTestScript builds the in-container command. PRIV_PKGS overrides the package
|
||||
// list (default: the full filtered set); PRIV_RUN adds a -run test-name filter.
|
||||
// Both empty reproduces the full privileged suite.
|
||||
func buildTestScript() string {
|
||||
pkgs := privilegedTestPackages + " | xargs"
|
||||
if p := os.Getenv("PRIV_PKGS"); p != "" {
|
||||
pkgs = "echo " + p + " | xargs"
|
||||
}
|
||||
|
||||
runFilter := ""
|
||||
if r := os.Getenv("PRIV_RUN"); r != "" {
|
||||
runFilter = "-run '" + r + "' "
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"apk update >/dev/null && apk add --no-cache %s >/dev/null && %s go test -buildvcs=false -tags 'devcert privileged' %s-v -timeout 20m -p 1",
|
||||
alpinePackages, pkgs, runFilter,
|
||||
)
|
||||
}
|
||||
@@ -336,11 +336,11 @@ type serviceClient struct {
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
|
||||
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||
// Name group name
|
||||
Name string
|
||||
// Description group description
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Privileged tests
|
||||
|
||||
Some tests in this repo need `root` or mutate host network state: they create
|
||||
TUN/WireGuard interfaces, open netlink/raw sockets, run eBPF programs, or shell
|
||||
out to `ip`/`iptables`/`nft`/`ifconfig`/`route`. Running them on a developer
|
||||
machine would require `sudo` and could leave stray interfaces or routes behind.
|
||||
|
||||
These tests are gated behind the **`privileged` build tag** so the default test
|
||||
run is host-safe.
|
||||
|
||||
## Running tests
|
||||
|
||||
```bash
|
||||
# Host-safe: excludes privileged tests. Runs as a normal user, no sudo.
|
||||
make test-unit
|
||||
# equivalently:
|
||||
go test -tags devcert ./...
|
||||
|
||||
# Privileged suite: runs the privileged-tagged tests inside a
|
||||
# --privileged --cap-add=NET_ADMIN container (requires Docker).
|
||||
make test-privileged
|
||||
|
||||
# Narrow the container run to a single test / package:
|
||||
PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||
```
|
||||
|
||||
`PRIV_RUN` adds a `-run` test-name filter and `PRIV_PKGS` overrides the package
|
||||
list; both are optional and default to the full privileged suite.
|
||||
|
||||
`make test-privileged` invokes the `ory/dockertest` harness in
|
||||
`client/testutil/privileged/`. The harness:
|
||||
|
||||
1. Skips immediately when it detects it is already inside the container
|
||||
(`DOCKER_CI=true`), so the privileged tests run in place instead of recursing.
|
||||
2. Otherwise spins up a `golang:1.25-alpine` container (matching CI),
|
||||
bind-mounts the repo and the host Go build/module caches, installs the
|
||||
required packages, and runs `go test -tags 'devcert privileged'` over the
|
||||
client packages.
|
||||
3. Streams the container's output to the test log and fails if the suite fails.
|
||||
|
||||
## Adding a privileged test
|
||||
|
||||
A test is privileged if it does any of:
|
||||
|
||||
- creates a real interface via `iface.NewWGIFace(...).Create()`,
|
||||
- opens a netlink or raw socket that hard-fails without `CAP_NET_ADMIN`,
|
||||
- runs an eBPF program (`ebpf.*.Listen()`),
|
||||
- shells out to `ip`, `iptables`, `nft`, `ifconfig`, or `route` to change state.
|
||||
|
||||
Add the tag to the **top** of the file, combined with any existing platform
|
||||
constraint:
|
||||
|
||||
```go
|
||||
//go:build privileged && linux
|
||||
|
||||
package foo
|
||||
```
|
||||
|
||||
If a file mixes privileged and pure-logic tests, **split it**: keep the pure
|
||||
tests (and any shared data — type/var declarations, table-driven `testCases`,
|
||||
helper interfaces) in an untagged file, and move the privileged tests into a
|
||||
`*_privileged_test.go` file with the tag. Shared declarations must stay untagged,
|
||||
otherwise the unprivileged files in the package will not compile.
|
||||
|
||||
Always verify both build modes compile on every target platform:
|
||||
|
||||
```bash
|
||||
go vet -tags devcert ./...
|
||||
go vet -tags 'devcert privileged' ./...
|
||||
```
|
||||
|
||||
## CI
|
||||
|
||||
- The `Client / Unit` job runs `go test -tags devcert` with **no** `sudo` — only
|
||||
host-safe tests.
|
||||
- The `Client (Docker) / Unit` job runs `go test -tags 'devcert privileged'`
|
||||
inside a `--privileged --cap-add=NET_ADMIN` container, which is where the
|
||||
privileged tests actually execute.
|
||||
11
go.mod
11
go.mod
@@ -78,12 +78,10 @@ require (
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/miekg/dns v1.1.72
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/moby/moby/api v1.54.1
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/oapi-codegen/runtime v1.1.2
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/ory/dockertest/v4 v4.0.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
|
||||
@@ -147,7 +145,7 @@ require (
|
||||
dario.cat/mergo v1.0.1 // indirect
|
||||
filippo.io/edwards25519 v1.1.1 // indirect
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
github.com/Azure/go-ntlmssp v0.1.0 // indirect
|
||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||
@@ -179,8 +177,6 @@ require (
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
@@ -275,12 +271,11 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/moby/client v0.4.0 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.5.0 // indirect
|
||||
github.com/moby/sys/user v0.3.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.2 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
||||
@@ -346,7 +341,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
28
go.sum
28
go.sum
@@ -23,8 +23,8 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
|
||||
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
@@ -117,10 +117,6 @@ github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
||||
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
@@ -484,10 +480,6 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx
|
||||
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4=
|
||||
github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs=
|
||||
github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw=
|
||||
github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g=
|
||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
|
||||
@@ -496,8 +488,8 @@ github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo=
|
||||
github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
@@ -518,8 +510,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a h1:3CWK+yTvRKOcC0Q8VCTGy4l60TEb27CQVS7LkMxwjmw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
@@ -550,8 +542,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/ory/dockertest/v4 v4.0.0 h1:i19aFsO/VXE0VrMk4ifnKW4G/KIJ93PCjLOslxXoPME=
|
||||
github.com/ory/dockertest/v4 v4.0.0/go.mod h1:b5Ofu8VIxWNhXFvQcLu17pRNQdoUBKtXBW74G4Ygzx8=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
@@ -983,13 +973,11 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
if file == "" {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
return newSQLite3(file).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/server/signer"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -79,7 +78,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Initialize SQLite storage
|
||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
||||
sqliteConfig := newSQLite3(dbPath)
|
||||
stor, err := sqliteConfig.Open(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||
|
||||
15
idp/dex/sqlite_cgo.go
Normal file
15
idp/dex/sqlite_cgo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
sql "github.com/dexidp/dex/storage/sql"
|
||||
)
|
||||
|
||||
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
|
||||
// struct that takes a File path. Non-CGO builds get an empty stub whose
|
||||
// Open() returns the dex "SQLite not available" error — correct behaviour
|
||||
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
|
||||
func newSQLite3(file string) *sql.SQLite3 {
|
||||
return &sql.SQLite3{File: file}
|
||||
}
|
||||
15
idp/dex/sqlite_nocgo.go
Normal file
15
idp/dex/sqlite_nocgo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
sql "github.com/dexidp/dex/storage/sql"
|
||||
)
|
||||
|
||||
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
|
||||
// Open() returns an error documenting the missing CGO support — correct
|
||||
// behaviour for cross-compiled artefacts that never actually run the
|
||||
// embedded IdP. The `file` argument is ignored.
|
||||
func newSQLite3(_ string) *sql.SQLite3 {
|
||||
return &sql.SQLite3{}
|
||||
}
|
||||
@@ -351,11 +351,6 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
# Record whether the operator explicitly pinned the server/proxy images via
|
||||
# env vars, so the agent-network preset can pick its own defaults without
|
||||
# clobbering an explicit override.
|
||||
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
@@ -403,53 +398,7 @@ configure_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
apply_agent_network_preset() {
|
||||
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
|
||||
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
|
||||
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
|
||||
# inputs we still need from the operator are the domain (handled by
|
||||
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
|
||||
# and the ACME email — both honor env vars first and fall back to a
|
||||
# prompt only when unset. CrowdSec is intentionally off.
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
# Agent-network ships dedicated server/proxy images. Honor an explicit
|
||||
# env override; otherwise pin the agent-network builds.
|
||||
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
|
||||
fi
|
||||
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
|
||||
fi
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||
fi
|
||||
|
||||
echo "" > /dev/stderr
|
||||
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
|
||||
echo " - reverse proxy: built-in Traefik" > /dev/stderr
|
||||
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
|
||||
echo " - server image: ${NETBIRD_SERVER_IMAGE}" > /dev/stderr
|
||||
echo " - proxy image: ${NETBIRD_PROXY_IMAGE}" > /dev/stderr
|
||||
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
|
||||
echo " - CrowdSec: disabled" > /dev/stderr
|
||||
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
configure_reverse_proxy() {
|
||||
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
apply_agent_network_preset
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Prompt for reverse proxy type
|
||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||
|
||||
@@ -961,15 +910,6 @@ NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||
# and exposes only the AI Observability + agent-network configuration
|
||||
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||
EOF
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1006,17 +946,6 @@ NB_PROXY_PROXY_PROTOCOL=true
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||
# ingress for agent-network synth services. Disables the public-facing
|
||||
# surface so the proxy serves only synth-generated routes (the
|
||||
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||
# listeners on the embedded netstack.
|
||||
NB_PROXY_PRIVATE=true
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||
cat <<EOF
|
||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||
@@ -1397,20 +1326,12 @@ print_builtin_traefik_instructions() {
|
||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||
fi
|
||||
echo ""
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.ai/pricing"
|
||||
echo " Documentation: https://docs.netbird.io/agent-network"
|
||||
else
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
fi
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
echo ""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo "NetBird Proxy:"
|
||||
@@ -1433,11 +1354,6 @@ print_builtin_traefik_instructions() {
|
||||
echo ""
|
||||
fi
|
||||
fi
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "Note: The public domain is only for setting up secure connections."
|
||||
echo "Your APIs and agent services remain private and are never exposed publicly."
|
||||
echo ""
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,12 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
// componentsDisabled, when true, forces the controller to emit legacy
|
||||
// proto.NetworkMap to every peer regardless of capability. Set once at
|
||||
// construction and never written after — readers race-free without a
|
||||
// mutex.
|
||||
componentsDisabled bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -89,12 +95,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
}
|
||||
}
|
||||
|
||||
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||
// component-based wire format for this peer.
|
||||
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||
}
|
||||
|
||||
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||
// literal.
|
||||
func parseBoolEnv(key string) bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -204,18 +225,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||
// the client merges it into Calculate()'s output the same
|
||||
// way the legacy server did via NetworkMap.Merge.
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
@@ -425,11 +454,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
@@ -440,7 +469,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
@@ -487,6 +521,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||
// encodes both into the wire envelope. Callers must gate on capability
|
||||
// themselves before dispatching here — this method does NOT branch on it.
|
||||
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
// Fetch the proxy network map fragment for this peer alongside the
|
||||
// components — same single-account-load path the streaming controller
|
||||
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||
// instead of waiting for the next streaming push.
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
|
||||
@@ -24,6 +24,10 @@ type Controller interface {
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
@@ -143,6 +143,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||
ret2, _ := ret[2].(*types.NetworkMap)
|
||||
ret3, _ := ret[3].([]*posture.Checks)
|
||||
ret4, _ := ret[4].(int64)
|
||||
ret5, _ := ret[5].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// PeerNeedsComponents mocks base method.
|
||||
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
813
management/internals/shared/grpc/components_encoder.go
Normal file
813
management/internals/shared/grpc/components_encoder.go
Normal file
@@ -0,0 +1,813 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||
const wgKeyRawLen = 32
|
||||
|
||||
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||
// The envelope is fully self-contained — every field needed by the client's
|
||||
// local Calculate() comes from the components struct itself. The only
|
||||
// externally-supplied data is the receiving peer's PeerConfig (which is
|
||||
// computed alongside the components in the network_map controller and reused
|
||||
// from the legacy proto path) and the dns_domain string.
|
||||
type ComponentsEnvelopeInput struct {
|
||||
Components *types.NetworkMapComponents
|
||||
PeerConfig *proto.PeerConfig
|
||||
DNSDomain string
|
||||
DNSForwarderPort int64
|
||||
// UserIDClaim is the OIDC claim name the client should embed in
|
||||
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||
// is OK — client treats empty as "no SshAuth to build".
|
||||
UserIDClaim string
|
||||
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||
// is present; encoder skips the field in that case.
|
||||
ProxyPatch *proto.ProxyPatch
|
||||
}
|
||||
|
||||
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||
// Go maps in their native (random) order. Indexes inside the envelope
|
||||
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||
// are self-consistent within a single encode, so the decoder reconstructs
|
||||
// the same typed objects regardless of emit order. Tests that need to
|
||||
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||
// not byte-equal.
|
||||
//
|
||||
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||
// index spaces are local to a single envelope.
|
||||
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||
c := in.Components
|
||||
|
||||
// Graceful degrade when components is nil — matches the legacy path's
|
||||
// behaviour for missing/unvalidated peers (return a NetworkMap with only
|
||||
// Network populated). The receiver gets an envelope it can decode
|
||||
// without crashing; AccountSettings stays non-nil so client-side
|
||||
// dereferences are safe.
|
||||
if c == nil {
|
||||
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||
// populated. The receiver gets enough to bootstrap (Network
|
||||
// identifier, dns_domain, account_settings) and nothing else.
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{
|
||||
Full: &proto.NetworkMapComponentsFull{
|
||||
PeerConfig: in.PeerConfig,
|
||||
DnsDomain: in.DNSDomain,
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
AccountSettings: &proto.AccountSettingsCompact{},
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||
// peers that exist only in c.RouterPeers would silently lose their
|
||||
// peer_index reference.
|
||||
enc := newComponentEncoder(c)
|
||||
enc.indexAllPeers()
|
||||
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||
|
||||
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||
// translate every *Policy pointer to a wire index.
|
||||
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||
|
||||
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||
// every encoder either reads from the dedup tables or works on
|
||||
// independent input.
|
||||
full := &proto.NetworkMapComponentsFull{
|
||||
Serial: networkSerial(c.Network),
|
||||
PeerConfig: in.PeerConfig,
|
||||
Network: toAccountNetwork(c.Network),
|
||||
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||
DnsDomain: in.DNSDomain,
|
||||
CustomZoneDomain: c.CustomZoneDomain,
|
||||
AgentVersions: enc.agentVersions,
|
||||
Peers: enc.peers,
|
||||
RouterPeerIndexes: routerIdxs,
|
||||
Policies: policies,
|
||||
Groups: enc.encodeGroups(),
|
||||
Routes: enc.encodeRoutes(c.Routes),
|
||||
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||
AccountZones: encodeCustomZones(c.AccountZones),
|
||||
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||
}
|
||||
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||
}
|
||||
}
|
||||
|
||||
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||
// production path always populates c.Network, but the encoder is exported
|
||||
// and a hand-built components struct may omit it.
|
||||
func networkSerial(n *types.Network) uint64 {
|
||||
if n == nil {
|
||||
return 0
|
||||
}
|
||||
return n.CurrentSerial()
|
||||
}
|
||||
|
||||
type componentEncoder struct {
|
||||
components *types.NetworkMapComponents
|
||||
|
||||
peerOrder map[string]uint32
|
||||
peers []*proto.PeerCompact
|
||||
|
||||
agentVersionOrder map[string]uint32
|
||||
agentVersions []string
|
||||
}
|
||||
|
||||
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||
return &componentEncoder{
|
||||
components: c,
|
||||
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||
agentVersionOrder: make(map[string]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) indexAllPeers() {
|
||||
for _, p := range e.components.Peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
e.appendPeer(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := uint32(len(e.peers))
|
||||
e.peerOrder[p.ID] = idx
|
||||
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||
return idx
|
||||
}
|
||||
|
||||
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||
return idx
|
||||
}
|
||||
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||
// a WtVersion don't force the table to materialise.
|
||||
if v == "" {
|
||||
idx := uint32(len(e.agentVersions))
|
||||
if idx == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
}
|
||||
e.agentVersionOrder[""] = idx
|
||||
return idx
|
||||
}
|
||||
if len(e.agentVersions) == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
e.agentVersionOrder[""] = 0
|
||||
}
|
||||
idx := uint32(len(e.agentVersions))
|
||||
e.agentVersionOrder[v] = idx
|
||||
e.agentVersions = append(e.agentVersions, v)
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||
if len(routers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(routers))
|
||||
for _, p := range routers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, e.appendPeer(p))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||
if len(e.components.Groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||
for _, g := range e.components.Groups {
|
||||
if !g.HasSeqID() {
|
||||
continue
|
||||
}
|
||||
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||
for _, peerID := range g.Peers {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
peerIdxs = append(peerIdxs, idx)
|
||||
}
|
||||
}
|
||||
out = append(out, &proto.GroupCompact{
|
||||
Id: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
PeerIndexes: peerIdxs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||
// that list — used by encodeResourcePoliciesMap to translate
|
||||
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||
|
||||
for _, pol := range policies {
|
||||
if !pol.HasSeqID() || !pol.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, r := range pol.Rules {
|
||||
if r == nil || !r.Enabled {
|
||||
continue
|
||||
}
|
||||
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||
out = append(out, e.encodePolicyRule(pol, r))
|
||||
}
|
||||
}
|
||||
return out, idxByPolicy
|
||||
}
|
||||
|
||||
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||
return &proto.PolicyCompact{
|
||||
Id: pol.AccountSeqID,
|
||||
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||
Bidirectional: r.Bidirectional,
|
||||
Ports: portsToUint32(r.Ports),
|
||||
PortRanges: portRangesToProto(r.PortRanges),
|
||||
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||
AuthorizedUser: r.AuthorizedUser,
|
||||
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||
SourceResource: e.resourceToProto(r.SourceResource),
|
||||
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||
}
|
||||
}
|
||||
|
||||
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||
// dropping any group that has no seq id assigned.
|
||||
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(src))
|
||||
for _, gid := range src {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// unionPolicies merges c.Policies with every policy referenced by
|
||||
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||
// from the wire and the client's resource-policy lookup would come back
|
||||
// empty.
|
||||
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||
// Fast path: non-router peers have no resource-only policies, so the
|
||||
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||
if len(resourcePolicies) == 0 {
|
||||
return policies
|
||||
}
|
||||
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
for _, list := range resourcePolicies {
|
||||
for _, p := range list {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||
// group xid → local-user names) to the wire form (map keyed by group
|
||||
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||
// matches how source/destination group references handle the same case.
|
||||
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||
for groupID, names := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||
g, ok := e.components.Groups[groupID]
|
||||
if !ok || !g.HasSeqID() {
|
||||
return 0, false
|
||||
}
|
||||
return g.AccountSeqID, true
|
||||
}
|
||||
|
||||
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||
// resources the peer id is converted to a peer index into the envelope's
|
||||
// peers array. For other resource types only the type string is shipped
|
||||
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||
// for "peer" — other types fall through to group-based lookup).
|
||||
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||
if r.ID == "" && r.Type == "" {
|
||||
return nil
|
||||
}
|
||||
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||
out.PeerIndexSet = true
|
||||
out.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||
// references handle the same case.
|
||||
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(xids))
|
||||
for _, xid := range xids {
|
||||
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// networkSeq translates a Network xid to its per-account integer id using
|
||||
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||
// the xid isn't known — callers decide whether to skip the parent record.
|
||||
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||
if xid == "" {
|
||||
return 0, false
|
||||
}
|
||||
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||
if !ok || seq == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return seq, true
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := &proto.DNSSettingsCompact{
|
||||
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||
}
|
||||
for _, gid := range s.DisabledManagementGroups {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||
if len(routes) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
rr := &proto.RouteRaw{
|
||||
Id: r.AccountSeqID,
|
||||
NetId: string(r.NetID),
|
||||
Description: r.Description,
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: int32(r.NetworkType),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
SkipAutoApply: r.SkipAutoApply,
|
||||
Domains: r.Domains.ToPunycodeList(),
|
||||
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
}
|
||||
if r.Network.IsValid() {
|
||||
rr.NetworkCidr = r.Network.String()
|
||||
}
|
||||
if r.Peer != "" {
|
||||
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||
rr.PeerIndexSet = true
|
||||
rr.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||
if len(nsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||
for _, nsg := range nsgs {
|
||||
if nsg == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NameServerGroupRaw{
|
||||
Id: nsg.AccountSeqID,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Nameservers: encodeNameServers(nsg.NameServers),
|
||||
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServer, 0, len(servers))
|
||||
for _, s := range servers {
|
||||
out = append(out, &proto.NameServer{
|
||||
IP: s.IP.String(),
|
||||
NSType: int64(s.NSType),
|
||||
Port: int64(s.Port),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, &proto.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int64(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int64(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||
if len(zones) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, &proto.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: encodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||
if len(resources) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||
for _, r := range resources {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkResourceRaw{
|
||||
Id: r.AccountSeqID,
|
||||
Name: r.Name,
|
||||
Description: r.Description,
|
||||
Type: string(r.Type),
|
||||
Address: r.Address,
|
||||
DomainValue: r.Domain,
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||
entry.NetworkSeq = seq
|
||||
}
|
||||
if r.Prefix.IsValid() {
|
||||
entry.PrefixCidr = r.Prefix.String()
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||
if len(routersMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||
for networkXID, routers := range routersMap {
|
||||
if len(routers) == 0 {
|
||||
continue
|
||||
}
|
||||
netSeq, ok := e.networkSeq(networkXID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||
for peerID, r := range routers {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkRouterEntry{
|
||||
Id: r.AccountSeqID,
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
entry.PeerIndexSet = true
|
||||
entry.PeerIndex = idx
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||
if len(rpm) == 0 {
|
||||
return nil
|
||||
}
|
||||
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||
// (small slice). Network resources without seq id are dropped, matching how
|
||||
// other components-without-seq are silently filtered.
|
||||
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||
for _, r := range e.components.NetworkResources {
|
||||
if r != nil && r.AccountSeqID != 0 {
|
||||
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||
}
|
||||
}
|
||||
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||
for resourceXID, policies := range rpm {
|
||||
seq, ok := resourceXIDToSeq[resourceXID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(policies)*2)
|
||||
for _, pol := range policies {
|
||||
idxs = append(idxs, policyToIdxs[pol]...)
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||
for groupID, userIDs := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok || len(userIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSetToSlice(s map[string]struct{}) []string {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s))
|
||||
for k := range s {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||
for checkXID, failedPeerIDs := range m {
|
||||
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||
if !ok || seq == 0 {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||
for peerID := range failedPeerIDs {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||
// (which shouldn't happen in production but the encoder is exported)
|
||||
// degrades to login_expiration_enabled = false, which makes
|
||||
// LoginExpired() return false for every peer.
|
||||
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||
if s == nil {
|
||||
return &proto.AccountSettingsCompact{}
|
||||
}
|
||||
return &proto.AccountSettingsCompact{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||
}
|
||||
}
|
||||
|
||||
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
out := &proto.AccountNetwork{
|
||||
Identifier: n.Identifier,
|
||||
NetCidr: n.Net.String(),
|
||||
Dns: n.Dns,
|
||||
Serial: n.CurrentSerial(),
|
||||
}
|
||||
if len(n.NetV6.IP) > 0 {
|
||||
out.NetV6Cidr = n.NetV6.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||
pc := &proto.PeerCompact{
|
||||
WgPubKey: decodeWgKey(p.Key),
|
||||
SshPubKey: []byte(p.SSHKey),
|
||||
DnsLabel: p.DNSLabel,
|
||||
AgentVersionIdx: agentVersionIdx,
|
||||
AddedWithSsoLogin: p.UserID != "",
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
SshEnabled: p.SSHEnabled,
|
||||
SupportsIpv6: p.SupportsIPv6(),
|
||||
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||
}
|
||||
if p.LastLogin != nil {
|
||||
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||
}
|
||||
switch {
|
||||
case !p.IP.IsValid():
|
||||
// leave Ip nil
|
||||
case p.IP.Is4() || p.IP.Is4In6():
|
||||
ip := p.IP.Unmap().As4()
|
||||
pc.Ip = ip[:]
|
||||
default:
|
||||
ip := p.IP.As16()
|
||||
pc.Ip = ip[:]
|
||||
}
|
||||
if p.IPv6.IsValid() {
|
||||
ip := p.IPv6.As16()
|
||||
pc.Ipv6 = ip[:]
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||
// key, or nil for an empty / malformed key.
|
||||
func decodeWgKey(s string) []byte {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, wgKeyRawLen)
|
||||
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||
if err != nil || n != wgKeyRawLen {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portsToUint32(ports []string) []uint32 {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(ports))
|
||||
for _, p := range ports {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, uint32(v))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
out = append(out, &proto.PortInfo_Range{
|
||||
Start: uint32(r.Start),
|
||||
End: uint32(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
|
||||
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||
// input compare byte-equal via proto.Equal.
|
||||
//
|
||||
// This lives on the test side — the encoder itself emits in map-iteration
|
||||
// order. Test-side normalization is the contract for "two encodes are
|
||||
// equivalent".
|
||||
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||
if full == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||
// index 0 by convention.
|
||||
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||
if len(full.AgentVersions) > 0 {
|
||||
// Pair version → original index, sort, rebuild.
|
||||
type avEntry struct {
|
||||
version string
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]avEntry, len(full.AgentVersions))
|
||||
for i, v := range full.AgentVersions {
|
||||
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||
}
|
||||
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||
// keeps the canonicalize output stable when two entries compare
|
||||
// equal (the encoder dedups, but defending against future inputs).
|
||||
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||
if a.version == "" && b.version != "" {
|
||||
return -1
|
||||
}
|
||||
if b.version == "" && a.version != "" {
|
||||
return 1
|
||||
}
|
||||
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||
})
|
||||
newVersions := make([]string, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
avRemap[e.oldIdx] = uint32(newIdx)
|
||||
newVersions[newIdx] = e.version
|
||||
}
|
||||
full.AgentVersions = newVersions
|
||||
}
|
||||
for _, p := range full.Peers {
|
||||
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||
p.AgentVersionIdx = newIdx
|
||||
}
|
||||
}
|
||||
|
||||
type peerEntry struct {
|
||||
peer *proto.PeerCompact
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]peerEntry, len(full.Peers))
|
||||
for i, p := range full.Peers {
|
||||
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||
}
|
||||
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||
// nil from malformed keys, or both empty for placeholders).
|
||||
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||
})
|
||||
|
||||
remap := make(map[uint32]uint32, len(entries))
|
||||
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
remap[e.oldIdx] = uint32(newIdx)
|
||||
newPeers[newIdx] = e.peer
|
||||
}
|
||||
full.Peers = newPeers
|
||||
|
||||
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||
for _, g := range full.Groups {
|
||||
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||
}
|
||||
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, r := range full.Routes {
|
||||
if r.PeerIndexSet {
|
||||
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||
r.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(r.GroupIds)
|
||||
slices.Sort(r.AccessControlGroupIds)
|
||||
slices.Sort(r.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, list := range full.RoutersMap {
|
||||
for _, entry := range list.Entries {
|
||||
if entry.PeerIndexSet {
|
||||
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||
entry.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(entry.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||
}
|
||||
|
||||
for _, set := range full.PostureFailedPeers {
|
||||
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||
}
|
||||
|
||||
for _, p := range full.Policies {
|
||||
slices.Sort(p.SourceGroupIds)
|
||||
slices.Sort(p.DestinationGroupIds)
|
||||
}
|
||||
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||
// a Policy has multiple rules) still get a deterministic order. After
|
||||
// sorting we remap indexes in ResourcePoliciesMap.
|
||||
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||
for i, p := range full.Policies {
|
||||
policyOldOrder[p] = uint32(i)
|
||||
}
|
||||
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||
return c
|
||||
}
|
||||
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||
})
|
||||
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||
for newIdx, p := range full.Policies {
|
||||
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||
}
|
||||
for _, idxs := range full.ResourcePoliciesMap {
|
||||
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||
}
|
||||
for _, list := range full.GroupIdToUserIds {
|
||||
slices.Sort(list.UserIds)
|
||||
}
|
||||
slices.Sort(full.AllowedUserIds)
|
||||
}
|
||||
|
||||
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||
out := make([]uint32, 0, len(idxs))
|
||||
for _, i := range idxs {
|
||||
if newIdx, ok := remap[i]; ok {
|
||||
out = append(out, newIdx)
|
||||
}
|
||||
}
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||
// the encoder is intentionally non-deterministic.
|
||||
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||
canonicalize(a.GetFull())
|
||||
canonicalize(b.GetFull())
|
||||
return goproto.Equal(a, b)
|
||||
}
|
||||
|
||||
func newTestComponents() *types.NetworkMapComponents {
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-a",
|
||||
Key: testWgKeyA,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peera",
|
||||
SSHKey: "ssh-a",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-b",
|
||||
Key: testWgKeyB,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||
DNSLabel: "peerb",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||
}
|
||||
peerC := &nbpeer.Peer{
|
||||
ID: "peer-c",
|
||||
Key: testWgKeyC,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
return &types.NetworkMapComponents{
|
||||
PeerID: "peer-a",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 7,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 2 * time.Hour,
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-a": peerA,
|
||||
"peer-b": peerB,
|
||||
"peer-c": peerC,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "pol-1",
|
||||
AccountSeqID: 10,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||
Ports: []string{"22", "80"},
|
||||
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full, "envelope must contain Full payload")
|
||||
|
||||
assert.EqualValues(t, 7, full.Serial)
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
|
||||
require.NotNil(t, full.Network)
|
||||
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||
|
||||
require.NotNil(t, full.AccountSettings)
|
||||
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
byLabel := map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||
// run produces different wire bytes, but the canonicalized form must
|
||||
// match.
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d must be semantically equivalent to first encode", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, got := range results {
|
||||
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"goroutine %d produced inequivalent envelope", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 2)
|
||||
|
||||
groupByID := map[uint32]*proto.GroupCompact{}
|
||||
for _, g := range full.Groups {
|
||||
groupByID[g.Id] = g
|
||||
}
|
||||
require.Contains(t, groupByID, uint32(1))
|
||||
require.Contains(t, groupByID, uint32(2))
|
||||
assert.Equal(t, "Src", groupByID[1].Name)
|
||||
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.EqualValues(t, 10, pc.Id)
|
||||
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||
assert.True(t, pc.Bidirectional)
|
||||
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||
require.Len(t, pc.PortRanges, 1)
|
||||
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.RouterPeerIndexes, 1)
|
||||
idx := full.RouterPeerIndexes[0]
|
||||
require.Less(t, int(idx), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||
"two distinct versions, order depends on map iteration")
|
||||
|
||||
idxByLabel := map[string]uint32{}
|
||||
for _, p := range full.Peers {
|
||||
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||
}
|
||||
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Policies[0].Enabled = false
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Groups["group-src"].AccountSeqID = 0
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||
// canonicalize identically.
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
|
||||
var byLabel = map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
v6Only := &nbpeer.Peer{
|
||||
ID: "peer-v6",
|
||||
Key: testWgKeyA,
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||
DNSLabel: "peerv6",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.Peers["peer-v6"] = v6Only
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerv6" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||
assert.Len(t, found.Ipv6, 16)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||
ID: "peer-noip",
|
||||
Key: testWgKeyA,
|
||||
DNSLabel: "peernoip",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peernoip" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found)
|
||||
assert.Empty(t, found.Ip)
|
||||
assert.Empty(t, found.Ipv6)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
}
|
||||
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Groups)
|
||||
assert.Empty(t, full.Policies)
|
||||
assert.Empty(t, full.RouterPeerIndexes)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
c.Peers["peer-a"].UserID = "user-1"
|
||||
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||
c.Peers["peer-a"].LastLogin = &now
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var pa *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peera" {
|
||||
pa = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pa)
|
||||
assert.True(t, pa.AddedWithSsoLogin)
|
||||
assert.True(t, pa.LoginExpirationEnabled)
|
||||
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||
|
||||
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||
var pb *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerb" {
|
||||
pb = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pb)
|
||||
assert.False(t, pb.AddedWithSsoLogin)
|
||||
assert.False(t, pb.LoginExpirationEnabled)
|
||||
assert.Zero(t, pb.LastLoginUnixNano)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{
|
||||
{
|
||||
ID: "route-peer",
|
||||
AccountSeqID: 100,
|
||||
NetID: "net-A",
|
||||
Description: "via peer-c",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Peer: "peer-c", // peer ID, not WG key
|
||||
Groups: []string{"group-src"},
|
||||
AccessControlGroups: []string{"group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-peergroup",
|
||||
AccountSeqID: 101,
|
||||
NetID: "net-B",
|
||||
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||
PeerGroups: []string{"group-src", "group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-no-seq",
|
||||
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 3)
|
||||
byNetID := map[string]*proto.RouteRaw{}
|
||||
for _, r := range full.Routes {
|
||||
byNetID[r.NetId] = r
|
||||
}
|
||||
|
||||
r1 := byNetID["net-A"]
|
||||
require.NotNil(t, r1)
|
||||
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||
assert.Empty(t, r1.PeerGroupIds)
|
||||
|
||||
r2 := byNetID["net-B"]
|
||||
require.NotNil(t, r2)
|
||||
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{{
|
||||
ID: "route-x",
|
||||
AccountSeqID: 100,
|
||||
Peer: "peer-not-in-components",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Enabled: true,
|
||||
}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 1)
|
||||
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||
"missing peer reference must not pretend to point at peer index 0")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||
// is the I1 case — without unionPolicies the encoder would silently
|
||||
// drop it from the wire.
|
||||
resourceOnlyPolicy := &types.Policy{
|
||||
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
}
|
||||
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||
}
|
||||
// Resource must appear in components.NetworkResources with a seq id —
|
||||
// encoder uses that to translate the xid map key to uint32.
|
||||
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||
|
||||
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||
policyIdxByID := map[uint32]uint32{}
|
||||
for i, p := range full.Policies {
|
||||
policyByID[p.Id] = p
|
||||
policyIdxByID[p.Id] = uint32(i)
|
||||
}
|
||||
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||
|
||||
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||
require.Len(t, idxs, 2)
|
||||
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||
"resource policies map must reference both wire policy indexes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||
}},
|
||||
Groups: []string{"group-src", "group-not-persisted"},
|
||||
Primary: true, Enabled: true,
|
||||
Domains: []string{"corp.example"},
|
||||
}}
|
||||
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.NameserverGroups, 1)
|
||||
nsg := full.NameserverGroups[0]
|
||||
assert.EqualValues(t, 50, nsg.Id)
|
||||
assert.Equal(t, "Main", nsg.Name)
|
||||
assert.True(t, nsg.Primary)
|
||||
require.Len(t, nsg.Nameservers, 1)
|
||||
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||
"check-1": {
|
||||
"peer-a": {},
|
||||
"peer-b": {},
|
||||
"peer-not-in-account": {},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {
|
||||
"peer-c": {
|
||||
ID: "router-1", AccountSeqID: 200,
|
||||
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
entries := full.RoutersMap[5].Entries
|
||||
require.Len(t, entries, 1)
|
||||
e := entries[0]
|
||||
assert.EqualValues(t, 200, e.Id)
|
||||
assert.True(t, e.PeerIndexSet)
|
||||
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||
assert.True(t, e.Masquerade)
|
||||
assert.EqualValues(t, 10, e.Metric)
|
||||
assert.True(t, e.Enabled)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||
// peer_index reference must still resolve.
|
||||
c := newTestComponents()
|
||||
delete(c.Peers, "peer-c")
|
||||
routerPeer := &nbpeer.Peer{
|
||||
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||
e := full.RoutersMap[5].Entries[0]
|
||||
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.DnsSettings)
|
||||
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.GroupIDToUserIDs = map[string][]string{
|
||||
"group-src": {"user-1", "user-2"},
|
||||
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||
"group-missing": {"user-4"}, // group not in components → drop
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||
}
|
||||
|
||||
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||
}
|
||||
|
||||
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||
nm := &types.NetworkMap{
|
||||
Peers: []*nbpeer.Peer{{
|
||||
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}},
|
||||
FirewallRules: []*types.FirewallRule{{
|
||||
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||
}},
|
||||
}
|
||||
|
||||
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||
|
||||
require.NotNil(t, patch)
|
||||
assert.Len(t, patch.Peers, 1)
|
||||
assert.Len(t, patch.FirewallRules, 1)
|
||||
}
|
||||
|
||||
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||
// pass-through in both encoder branches (normal path + nil-Components
|
||||
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
|
||||
// from one of the envelope struct literals.
|
||||
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||
patch := &proto.ProxyPatch{
|
||||
ForwardingRules: []*proto.ForwardingRule{{
|
||||
Protocol: proto.RuleProtocol_TCP,
|
||||
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run("normal_path", func(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
|
||||
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||
// behaviour for missing/unvalidated peers.
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
// AccountSettings deliberately nil
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||
}
|
||||
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||
// field is intentionally left empty — capable peers ignore it and the
|
||||
// envelope alone is the authoritative wire shape.
|
||||
//
|
||||
// PeerConfig is computed once server-side using the receiving peer's own
|
||||
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||
// client even though the server-side PeerConfig reports false.
|
||||
func ToComponentSyncResponse(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
httpConfig *nbconfig.HttpServerConfig,
|
||||
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||
peer *nbpeer.Peer,
|
||||
turnCredentials *Token,
|
||||
relayCredentials *Token,
|
||||
components *types.NetworkMapComponents,
|
||||
proxyPatch *types.NetworkMap,
|
||||
dnsName string,
|
||||
checks []*posture.Checks,
|
||||
settings *types.Settings,
|
||||
extraSettings *types.ExtraSettings,
|
||||
peerGroups []string,
|
||||
dnsFwdPort int64,
|
||||
) *proto.SyncResponse {
|
||||
network := networkOrZero(components)
|
||||
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: peerConfig,
|
||||
DNSDomain: dnsName,
|
||||
DNSForwarderPort: dnsFwdPort,
|
||||
UserIDClaim: userIDClaim,
|
||||
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||
})
|
||||
|
||||
resp := &proto.SyncResponse{
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMapEnvelope: envelope,
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||
// dereferences network.Net which would panic on nil.
|
||||
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||
if c == nil || c.Network == nil {
|
||||
return &types.Network{}
|
||||
}
|
||||
return c.Network
|
||||
}
|
||||
|
||||
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||
// patch the components envelope ships alongside. Returns nil when there are
|
||||
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||
// sees no patch and skips the merge step entirely.
|
||||
//
|
||||
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||
// delivers fragments pre-expanded — there's no raw component shape to
|
||||
// derive them from. Components purity isn't violated: proxy data isn't
|
||||
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||
// client merges it on top of its locally-computed NetworkMap.
|
||||
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
patch := &proto.ProxyPatch{
|
||||
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||
}
|
||||
if len(nm.ForwardingRules) > 0 {
|
||||
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||
for _, r := range nm.ForwardingRules {
|
||||
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||
}
|
||||
}
|
||||
return patch
|
||||
}
|
||||
|
||||
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||
// without this helper the field would be incorrectly false for any peer
|
||||
// that's the destination of an SSH-enabling policy without having
|
||||
// peer.SSHEnabled set locally.
|
||||
//
|
||||
// Mirrors the two activation paths Calculate() uses:
|
||||
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||
// destinations.
|
||||
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||
//
|
||||
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||
// runs Calculate() over the envelope.
|
||||
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||
if c == nil || peer == nil {
|
||||
return false
|
||||
}
|
||||
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||
// exist in c.Peers, otherwise no rule applies to it.
|
||||
if _, ok := c.Peers[peer.ID]; !ok {
|
||||
return false
|
||||
}
|
||||
for _, policy := range c.Policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||
// peer itself has SSH enabled locally.
|
||||
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||
if rule == nil || !rule.Enabled {
|
||||
return false
|
||||
}
|
||||
if !peerInDestinations(c, rule, peer.ID) {
|
||||
return false
|
||||
}
|
||||
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
return true
|
||||
}
|
||||
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||
// that exactly to avoid silent divergence).
|
||||
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
return rule.DestinationResource.ID == peerID
|
||||
}
|
||||
for _, groupID := range rule.Destinations {
|
||||
if c.IsPeerInGroup(peerID, groupID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||
// the destination peer has SSHEnabled=true locally.
|
||||
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||
const targetPeerID = "target"
|
||||
const targetGroupID = "g_dst"
|
||||
|
||||
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||
return &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||
Groups: map[string]*types.Group{targetGroupID: group},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
}},
|
||||
}, peer
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
peerSSH bool
|
||||
rule types.PolicyRule
|
||||
wantEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22022-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-all-protocol-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-port-range-covers-22",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "tcp-80-no-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "disabled-rule-skipped",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-not-in-destinations",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g_other"}, // target not in this group
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-typed-destination-resource-matches",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||
got := computeSSHEnabledForPeer(c, peer)
|
||||
assert.Equal(t, tc.wantEnabled, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||
// belt-and-suspenders presence guard mirroring Calculate's
|
||||
// getAllPeersFromGroups invariant.
|
||||
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||
c := &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||
Groups: map[string]*types.Group{
|
||||
"g": {ID: "g", Peers: []string{"missing"}},
|
||||
},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g"},
|
||||
}},
|
||||
}},
|
||||
}
|
||||
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||
"missing target peer must short-circuit to false, not consult policies")
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||
// components on graceful-degrade paths.
|
||||
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||
}
|
||||
@@ -10,24 +10,20 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -159,8 +155,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
@@ -173,7 +169,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
|
||||
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
|
||||
response.RemotePeers = remotePeers
|
||||
@@ -183,13 +179,13 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
@@ -202,7 +198,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
@@ -242,33 +238,6 @@ func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||
return timestamppb.New(deadline)
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
if nbversion.IsDevelopmentVersion(peerVersion) {
|
||||
return true
|
||||
@@ -282,51 +251,6 @@ func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
@@ -344,203 +268,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
||||
// alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if shouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||
if config == nil || config.AuthAudience == "" {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -62,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -100,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -108,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
@@ -142,7 +142,6 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
|
||||
@@ -1016,7 +1016,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
var plainResp *proto.SyncResponse
|
||||
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||
// computed and recompute the raw components instead. This wastes one
|
||||
// Calculate() call per initial-sync — the component-based wire
|
||||
// format is what the peer actually consumes. The streaming path
|
||||
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||
// because it dispatches by capability before computing.
|
||||
//
|
||||
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
|
||||
// interfaces to return PeerNetworkMapResult so the initial-sync path
|
||||
// stops doing duplicate work. Deferred until the client-side
|
||||
// decoder lands and there's a real deployment of capability=3 peers
|
||||
// worth optimizing for.
|
||||
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||
}
|
||||
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
} else {
|
||||
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
}
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -1636,6 +1636,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range newGroupsToCreate {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
|
||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -3170,6 +3170,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
|
||||
var newJWTGroup *types.Group
|
||||
for _, g := range groups {
|
||||
if g.Name == "group3" {
|
||||
newJWTGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
|
||||
@@ -93,6 +93,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -158,6 +164,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -236,6 +244,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -327,6 +341,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -341,7 +361,6 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
|
||||
var err error
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
|
||||
156
management/server/migration/account_seq.go
Normal file
156
management/server/migration/account_seq.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||
// no-op once everything is consistent.
|
||||
//
|
||||
// Implemented as two table-wide SQL statements with window functions, one
|
||||
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||
// well under a second instead of the per-account-loop ~2 minutes.
|
||||
//
|
||||
// orderColumn is the column to use when assigning the deterministic ordering
|
||||
// (typically the primary-key string id).
|
||||
func BackfillAccountSeqIDs[T any](
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
entity types.AccountSeqEntity,
|
||||
orderColumn string,
|
||||
) error {
|
||||
var model T
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
table := quoteIdent(db, stmt.Schema.Table)
|
||||
orderCol := quoteIdent(db, orderColumn)
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var pending int64
|
||||
if err := tx.Raw(
|
||||
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||
).Scan(&pending).Error; err != nil {
|
||||
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||
}
|
||||
|
||||
if pending > 0 {
|
||||
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||
return fmt.Errorf("rank %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func quoteIdent(db *gorm.DB, name string) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql":
|
||||
return "`" + name + "`"
|
||||
case "postgres":
|
||||
return `"` + name + `"`
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres", "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
WITH max_seq AS (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
),
|
||||
ranked AS (
|
||||
SELECT p.id,
|
||||
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||
FROM %s p
|
||||
JOIN max_seq m ON p.account_id = m.account_id
|
||||
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||
)
|
||||
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||
FROM ranked
|
||||
WHERE %s.id = ranked.id
|
||||
`, table, orderCol, table, table, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
UPDATE %s p
|
||||
JOIN (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
) m ON p.account_id = m.account_id
|
||||
JOIN (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||
) r ON p.id = r.id
|
||||
SET p.account_seq_id = m.max_seq + r.rn
|
||||
`, table, table, orderCol, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql).Error
|
||||
}
|
||||
|
||||
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`, table)
|
||||
case "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql, string(entity)).Error
|
||||
}
|
||||
@@ -67,6 +67,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNSGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -116,6 +122,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -71,9 +72,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
network.ID = xid.New().String()
|
||||
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network seq id: %w", err)
|
||||
}
|
||||
network.AccountSeqID = seq
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
||||
@@ -102,14 +114,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
network.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
|
||||
@@ -255,3 +255,73 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedNetwork)
|
||||
}
|
||||
|
||||
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
|
||||
// non-zero AccountSeqID on the persisted network (allocated through the
|
||||
// account_seq_counters table).
|
||||
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-allocation-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
|
||||
// AccountSeqID even when the caller passes a zero value (the shape REST
|
||||
// handlers produce because the field is `json:"-"`).
|
||||
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-preserve-original",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &types.Network{
|
||||
AccountID: accountID,
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = manager.UpdateNetwork(ctx, userID, update)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -146,6 +146,12 @@ func (m *managerImpl) createResourceInTransaction(ctx context.Context, transacti
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to allocate network resource seq id: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
@@ -245,6 +251,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = oldResource.AccountSeqID
|
||||
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,9 @@ type NetworkResource struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
|
||||
Name string
|
||||
Description string
|
||||
Type NetworkResourceType
|
||||
@@ -93,17 +96,18 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
|
||||
|
||||
func (n *NetworkResource) Copy() *NetworkResource {
|
||||
return &NetworkResource{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -104,6 +105,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", err)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
|
||||
err = transaction.CreateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
@@ -199,6 +206,11 @@ func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction
|
||||
return nil, nil, affectedpeers.Change{}, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
// Preserve AccountSeqID from the existing router so the upstream
|
||||
// UpdateNetworkRouter (which does Updates(router) with Select("*"))
|
||||
// doesn't clobber it with the request's zero value.
|
||||
router.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ type NetworkRouter struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
Masquerade bool
|
||||
@@ -78,14 +81,15 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
|
||||
|
||||
func (n *NetworkRouter) Copy() *NetworkRouter {
|
||||
return &NetworkRouter{
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,12 +7,24 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
|
||||
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the network has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip networks that return false here.
|
||||
func (n *Network) HasSeqID() bool {
|
||||
return n != nil && n.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func NewNetwork(accountId, name, description string) *Network {
|
||||
return &Network{
|
||||
ID: xid.New().String(),
|
||||
@@ -41,13 +53,14 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a copy of a posture checks.
|
||||
// Copy returns a copy of a network.
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
|
||||
// Peer capability constants mirror the proto enum values.
|
||||
const (
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilityComponentNetworkMap int32 = 3
|
||||
)
|
||||
|
||||
// Peer represents a machine connected to the network.
|
||||
@@ -218,6 +219,14 @@ func (p *Peer) SupportsSourcePrefixes() bool {
|
||||
return p.HasCapability(PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
|
||||
// SupportsComponentNetworkMap reports whether the peer assembles its
|
||||
// NetworkMap from server-shipped components instead of consuming a fully
|
||||
// expanded NetworkMap. Determines whether the network_map controller skips
|
||||
// Calculate() server-side and emits the components envelope.
|
||||
func (p *Peer) SupportsComponentNetworkMap() bool {
|
||||
return p.HasCapability(PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func capabilitiesEqual(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
|
||||
@@ -67,10 +67,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
policy.AccountSeqID = existingPolicy.AccountSeqID
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -49,6 +49,10 @@ type Checks struct {
|
||||
// AccountID is a reference to the Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
|
||||
|
||||
// Checks is a set of objects that perform the actual checks
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
@@ -93,6 +97,13 @@ func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Pe
|
||||
return changed
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the posture check has been persisted long enough
|
||||
// to have a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip checks that return false here.
|
||||
func (pc *Checks) HasSeqID() bool {
|
||||
return pc != nil && pc.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
@@ -163,11 +174,12 @@ func (*Checks) TableName() string {
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (pc *Checks) Copy() *Checks {
|
||||
checks := &Checks{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
AccountSeqID: pc.AccountSeqID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -52,7 +53,19 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = seq
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user