mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-28 18:59:57 +00:00
Compare commits
37 Commits
enable-laz
...
client_lif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7706f578fe | ||
|
|
daf5026192 | ||
|
|
ec18b07959 | ||
|
|
9628f016da | ||
|
|
b39e9df194 | ||
|
|
0388e0f262 | ||
|
|
86f896723d | ||
|
|
29ee84999c | ||
|
|
0e8fd22f36 | ||
|
|
ff98105212 | ||
|
|
6465997a69 | ||
|
|
3204270c4b | ||
|
|
6d3bcef2c4 | ||
|
|
5d7cb30e5b | ||
|
|
aff5da2c8e | ||
|
|
9b179be324 | ||
|
|
33e7b6a8f1 | ||
|
|
e0cff5e240 | ||
|
|
0085aebf77 | ||
|
|
91d2d341b7 | ||
|
|
8d46580c13 | ||
|
|
b42fe6a10f | ||
|
|
0f5d7fdc07 | ||
|
|
13c78d98f5 | ||
|
|
d1229ed84c | ||
|
|
9758145517 | ||
|
|
200a5a6a70 | ||
|
|
1f7b1ea863 | ||
|
|
4abb10c1aa | ||
|
|
a45cefe57a | ||
|
|
a6d504633f | ||
|
|
70f2097fff | ||
|
|
befa9a879c | ||
|
|
4152c41796 | ||
|
|
8b76b3d824 | ||
|
|
0503a18644 | ||
|
|
ec6512d660 |
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -59,12 +59,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
|
||||
8
.github/workflows/golang-test-darwin.yml
vendored
8
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,18 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
id: test
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
|
||||
78
.github/workflows/golang-test-linux.yml
vendored
78
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -119,12 +119,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -162,7 +162,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -175,12 +175,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -246,12 +246,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -266,7 +266,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -306,12 +306,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -325,7 +325,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -363,12 +363,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -383,7 +383,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -407,7 +407,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -424,12 +424,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -440,7 +440,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -484,7 +484,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -529,12 +529,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -545,7 +545,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -579,11 +579,10 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -624,12 +623,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -640,7 +639,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -674,13 +673,12 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
@@ -694,12 +692,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -710,7 +708,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -736,7 +734,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
12
.github/workflows/mobile-build-validation.yml
vendored
12
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,13 +28,13 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -54,11 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
38
.github/workflows/release.yml
vendored
38
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -166,12 +166,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
@@ -221,7 +221,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -374,12 +374,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -464,17 +464,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -522,7 +522,7 @@ jobs:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -534,13 +534,13 @@ jobs:
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
|
||||
14
.github/workflows/test-infrastructure-files.yml
vendored
14
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,17 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/management:latest .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/relay:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -44,11 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
@@ -247,7 +247,7 @@ dockers_v2:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
@@ -295,7 +295,7 @@ dockers_v2:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
@@ -317,7 +317,7 @@ dockers_v2:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
@@ -339,7 +339,7 @@ dockers_v2:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
@@ -361,7 +361,7 @@ dockers_v2:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
@@ -383,7 +383,7 @@ dockers_v2:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
@@ -405,7 +405,7 @@ dockers_v2:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
@@ -462,13 +462,9 @@ checksum:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
@@ -37,11 +37,6 @@
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
> ### 🤖 NetBird Agent Network (Beta)
|
||||
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# NetBird Agent Network
|
||||
|
||||
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||
policy, with no API keys to leak.
|
||||
|
||||
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||
> infrastructure.
|
||||
|
||||
## How it works
|
||||
|
||||
Agent Network is built on two existing NetBird capabilities:
|
||||
|
||||
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||
key server-side, forwards to the API or gateway, and records usage.
|
||||
|
||||
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||
components:
|
||||
|
||||
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||
and runs the per-request middleware pipeline.
|
||||
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||
and usage/access logs.
|
||||
|
||||
## Documentation
|
||||
|
||||
Full documentation, architecture, and quickstart:
|
||||
**https://docs.netbird.io/agent-network**
|
||||
@@ -151,9 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// RunWithoutLogin is the run function we apply when the backend has been started without UI (i.e. after reboot).
|
||||
@@ -186,9 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: string(activeProf.ID),
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
|
||||
@@ -138,23 +138,26 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
profileName := args[0]
|
||||
|
||||
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
@@ -163,6 +166,7 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
@@ -326,19 +330,3 @@ func wrapAmbiguityError(err error, handle string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||
// and returns the new profile's ID. It is the single entry point for profile
|
||||
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||
// <name>` auto-create path.
|
||||
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
@@ -261,17 +262,46 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
return prefix + upper
|
||||
}
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
// DialClientGRPCServer returns client connection to the daemon server. It waits
|
||||
// (up to the timeout) for the daemon to become reachable so an `up` issued right
|
||||
// after `service start` tolerates the startup race. Instead of grpc's blocking
|
||||
// dial — whose raw "transport failed" retry warnings are silenced by the logger
|
||||
// config — we drive the wait ourselves and emit one clean line per failed attempt.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
return grpc.DialContext(
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
strings.TrimPrefix(addr, "tcp://"),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.Connect()
|
||||
for {
|
||||
state := conn.GetState()
|
||||
if state == connectivity.Ready {
|
||||
return conn, nil
|
||||
}
|
||||
// Log only once the connection has actually failed — not during the
|
||||
// brief Idle/Connecting phase on a healthy daemon (avoids a spurious
|
||||
// line + wait when the daemon is already up).
|
||||
if state == connectivity.TransientFailure {
|
||||
log.Infof("waiting for the netbird daemon to become available at %s...", addr)
|
||||
}
|
||||
// Wake on the next state change, but at least every second so a stuck
|
||||
// TransientFailure re-logs at a steady cadence until the timeout.
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, time.Second)
|
||||
conn.WaitForStateChange(waitCtx, state)
|
||||
waitCancel()
|
||||
if ctx.Err() != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("daemon not reachable at %s: %w", addr, ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackOff executes the function in a backoff cycle.
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -110,10 +111,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve the active profile's display name via the daemon, which runs
|
||||
// as root and can read the per-user profile files. The local profile
|
||||
// manager only knows the active profile ID, not its display name.
|
||||
profName := getActiveProfileName(ctx)
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
@@ -165,25 +167,6 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getActiveProfileName asks the daemon for the active profile's display
|
||||
// name. The daemon runs as root and can read the per-user profile files to
|
||||
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||
// error so status output degrades gracefully.
|
||||
func getActiveProfileName(ctx context.Context) string {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return resp.GetProfileName()
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "idle", "connecting", "connected":
|
||||
|
||||
@@ -128,9 +128,15 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
profileSwitched = true
|
||||
}
|
||||
|
||||
@@ -145,52 +151,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||
}
|
||||
|
||||
// switchOrCreateProfile switches the active profile to the one identified by
|
||||
// handle, creating it first when it does not exist yet. This restores the
|
||||
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||
// missing profile instead of failing.
|
||||
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok || st.Code() != codes.NotFound {
|
||||
return err
|
||||
}
|
||||
// Don't fail immediately on a create error: a concurrent run may
|
||||
// have created the profile between the NotFound above and this
|
||||
// call, in which case the retried switch still succeeds. Only
|
||||
// surface the create error if the switch also fails.
|
||||
_, createErr := createProfile(ctx, handle, username)
|
||||
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||
if createErr != nil {
|
||||
return fmt.Errorf("create profile: %w", createErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProfile dials the daemon and creates a new profile with the given
|
||||
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||
// when a daemon client is already available to reuse the connection.
|
||||
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||
// override the default profile filepath if provided
|
||||
if configPath != "" {
|
||||
@@ -241,10 +201,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
return connectClient.Run(config, nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
|
||||
@@ -264,34 +264,24 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||
client := internal.NewConnectClient(ctx, c.recorder)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// The supervisor owns the run; we wait until it is established, ends with a
|
||||
// startup error (permanent backoff err), or startCtx expires.
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
client.RunAsync(c.config, nil)
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||
// run loop to tear the engine down, so this cancel() is no longer
|
||||
// required to break the deadlock and could be removed. It is kept as a
|
||||
// defensive belt-and-suspenders: cancelling the parent context first
|
||||
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||
if err := client.WaitEstablishedOrDone(startCtx); err != nil {
|
||||
// Either startCtx expired while connecting, or the run ended before it
|
||||
// established. Cancel the client context before stopping: Engine.Start
|
||||
// blocks on the signal stream while holding the engine mutex and only
|
||||
// unblocks on cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
return fmt.Errorf("stop error after startup failure. Stop error: %w. Startup: %w", stopErr, err)
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -49,17 +49,23 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
// androidMobileDep is set on Android to inject the MobileDependency for runs
|
||||
// started through the generic entry points (Run/RunAsync, e.g. embed.Client).
|
||||
// nil on other platforms, where the dependency is empty.
|
||||
var androidMobileDep func(config *profilemanager.Config) MobileDependency
|
||||
|
||||
// mobileDependency returns the MobileDependency for a run started via the
|
||||
// generic entry points. On Android the androidMobileDep provider supplies
|
||||
// platform stubs (or real implementations); elsewhere it is empty.
|
||||
func (c *ConnectClient) mobileDependency(config *profilemanager.Config) MobileDependency {
|
||||
if androidMobileDep != nil {
|
||||
return androidMobileDep(config)
|
||||
}
|
||||
return MobileDependency{}
|
||||
}
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
runCancel context.CancelFunc
|
||||
runExited chan struct{}
|
||||
runOnce sync.Once
|
||||
runStarted atomic.Bool
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
engine *Engine
|
||||
@@ -68,41 +74,62 @@ type ConnectClient struct {
|
||||
updateManager *updater.Manager
|
||||
|
||||
persistSyncResponse bool
|
||||
|
||||
// sup serializes all start/stop requests so two lifecycle operations can
|
||||
// never overlap. See connect_lifecycle.go.
|
||||
sup *supervisor
|
||||
}
|
||||
|
||||
func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
) *ConnectClient {
|
||||
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||
// loop. runCancel is set once at construction, so Stop can call it without
|
||||
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
return &ConnectClient{
|
||||
ctx: runCtx,
|
||||
runCancel: runCancel,
|
||||
runExited: make(chan struct{}),
|
||||
config: config,
|
||||
c := &ConnectClient{
|
||||
ctx: ctx,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
c.sup = newSupervisor(ctx, c.run)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
// Run with main logic. md carries optional gRPC metadata (e.g. the UI
|
||||
// user-agent) to forward to the management/signal services; nil when none.
|
||||
func (c *ConnectClient) Run(config *profilemanager.Config, md metadata.MD, logPath string) error {
|
||||
return c.sup.start(config, md, c.mobileDependency(config), logPath)
|
||||
}
|
||||
|
||||
// RunAsync starts a client run without blocking. Used by the daemon and embed,
|
||||
// which drive the lifecycle through the supervisor rather than blocking on Run;
|
||||
// they then wait for the outcome via WaitEstablishedOrDone. The run's lifecycle
|
||||
// channels are created and owned by the supervisor — callers never hold them.
|
||||
func (c *ConnectClient) RunAsync(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.startAsync(config, md, c.mobileDependency(config), "", nil)
|
||||
}
|
||||
|
||||
// Restart atomically stops any in-flight run and starts a fresh one with the
|
||||
// given config. The stop+start happens as a single supervisor operation, so no
|
||||
// other lifecycle request can interleave between them — used for explicit
|
||||
// restarts (e.g. an MDM policy change) that must not expose a "stopped" window.
|
||||
func (c *ConnectClient) Restart(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.restartAsync(config, md, c.mobileDependency(config), "")
|
||||
}
|
||||
|
||||
// WaitEstablishedOrDone blocks until the in-flight run becomes established (nil),
|
||||
// ends before that (the run error, or a sentinel on a clean stop), or ctx is
|
||||
// cancelled. Returns errNoRunInFlight if no run is in flight. Wraps the wait on
|
||||
// the supervisor-owned channels so callers never touch them directly.
|
||||
func (c *ConnectClient) WaitEstablishedOrDone(ctx context.Context) error {
|
||||
return c.sup.waitEstablishedOrDone(ctx)
|
||||
}
|
||||
|
||||
// RunOnAndroid runs the main logic on a mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
config *profilemanager.Config,
|
||||
tunAdapter device.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
@@ -121,10 +148,11 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
return c.sup.start(config, nil, mobileDependency, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
config *profilemanager.Config,
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
@@ -142,15 +170,12 @@ func (c *ConnectClient) RunOniOS(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, logFilePath)
|
||||
return c.sup.start(config, nil, mobileDependency, logFilePath)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||
// the loop to finish (and skip the wait if the loop never ran).
|
||||
c.runStarted.Store(true)
|
||||
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||
|
||||
// run executes a single client run. runCtx is owned by the supervisor: cancelling
|
||||
// it tears the run down (it is the parent of the per-attempt engine context).
|
||||
func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Config, mobileDependency MobileDependency, connEstablishedChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -214,18 +239,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}()
|
||||
|
||||
wrapErr := state.Wrap
|
||||
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if c.config.ManagementURL.Scheme == "https" {
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -259,13 +284,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if the context is cancelled we do not start a new backoff cycle
|
||||
if c.ctx.Err() != nil {
|
||||
if runCtx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.Set(StatusConnecting)
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
engineCtx, cancel := context.WithCancel(runCtx)
|
||||
defer func() {
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
@@ -273,8 +298,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
@@ -291,7 +316,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
defer func() {
|
||||
if err = mgmClient.Close(); err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
@@ -300,13 +325,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginStarted := time.Now()
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, config)
|
||||
if err != nil {
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
c.runCancel()
|
||||
// No teardown needed: login fails before the engine is started
|
||||
// (engine.Start is below), so there is nothing running to stop.
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -360,7 +386,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -404,7 +430,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
@@ -412,12 +438,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
// The supervisor owns connEstablishedChan and it is always present. Guard
|
||||
// against a double close: operation re-runs on ErrResetConnection retries
|
||||
// within the same run, and the channel is closed only on the first connect.
|
||||
select {
|
||||
case <-connEstablishedChan:
|
||||
default:
|
||||
close(connEstablishedChan)
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -426,8 +453,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||
|
||||
// Always tear the engine down once its context is cancelled. engine.Stop
|
||||
// is nil-guarded per component, so calling it unconditionally is safe and
|
||||
// avoids both the data race on engine.wgInterface and skipping teardown
|
||||
// when the interface was never brought up (e.g. a mid-start failure).
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
@@ -445,12 +474,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// Login failed permanently: the engine was never started, so there
|
||||
// is nothing to tear down — just record that a login is needed.
|
||||
state.Set(StatusNeedsLogin)
|
||||
c.runCancel()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -471,6 +501,22 @@ func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
return relayCfg.GetUrls(), token
|
||||
}
|
||||
|
||||
// ConnectionRunning reports whether a connection run is currently in flight
|
||||
// (connecting, connected, or reconnecting). Answered by the supervisor via a
|
||||
// serialized query, so it settles behind an in-flight stop. Distinct from
|
||||
// ServiceRunning, which reports whether the service itself is alive.
|
||||
func (c *ConnectClient) ConnectionRunning() bool {
|
||||
return c.sup.isRunning()
|
||||
}
|
||||
|
||||
// ServiceRunning reports whether the client's lifecycle supervisor is alive and
|
||||
// able to accept start/stop commands — i.e. its context has not been cancelled
|
||||
// (the daemon is not shutting down). Independent of whether a connection run is
|
||||
// up (that is ConnectionRunning).
|
||||
func (c *ConnectClient) ServiceRunning() bool {
|
||||
return c.sup.ctx.Err() == nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
if c == nil {
|
||||
return nil
|
||||
@@ -527,12 +573,10 @@ func (c *ConnectClient) Status() StatusType {
|
||||
return status
|
||||
}
|
||||
|
||||
// Stop serializes a stop request through the lifecycle supervisor and blocks
|
||||
// until the in-flight run is fully torn down.
|
||||
func (c *ConnectClient) Stop() error {
|
||||
c.runCancel()
|
||||
if c.runStarted.Load() {
|
||||
<-c.runExited
|
||||
}
|
||||
return nil
|
||||
return c.sup.stop()
|
||||
}
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
@@ -59,19 +60,17 @@ var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// Wire up the default MobileDependency provider so embed.Client.Start() works
|
||||
// on Android with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
// Applications that need P2P ICE or real DNS should replace this by setting
|
||||
// androidMobileDep before calling Start().
|
||||
androidMobileDep = func(config *profilemanager.Config) MobileDependency {
|
||||
return mobileDependencyForEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,23 +10,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
// mobileDependencyForEmbed builds the MobileDependency used by embed.Client on
|
||||
// Android so the engine's existing Android code paths work unchanged.
|
||||
func mobileDependencyForEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
) MobileDependency {
|
||||
return MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
|
||||
362
client/internal/connect_lifecycle.go
Normal file
362
client/internal/connect_lifecycle.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
// Sentinel errors reported by the lifecycle supervisor.
var (
	// errAlreadyRunning rejects a start request issued while a run is still
	// in flight.
	errAlreadyRunning = errors.New("client is already running")

	// errNoRunInFlight is what waitEstablishedOrDone reports when there is no
	// active run to wait on.
	errNoRunInFlight = errors.New("no connection run in flight")

	// errStoppedBeforeEstablished signals that a run finished without an error
	// before its connection was ever established.
	errStoppedBeforeEstablished = errors.New("run stopped before the connection was established")
)
|
||||
|
||||
// lifecycleOp identifies the kind of request carried by a lifecycleCmd and
// handled, one at a time, by the supervisor loop.
type lifecycleOp int

const (
	// opStart asks the supervisor to begin a new run.
	opStart lifecycleOp = iota
	// opStop asks the supervisor to tear down the in-flight run.
	opStop
	// opRestart asks for a stop followed by a start in one loop turn.
	opRestart
	// opStatus asks whether a run is currently in flight.
	opStatus
	// opWaitEstablished asks to be notified once the run is established or ends.
	opWaitEstablished
)
|
||||
|
||||
// lifecycleCmd is a single lifecycle request handed to the supervisor goroutine.
|
||||
// They all flow through the same cmdCh so they are strictly ordered (FIFO) with
|
||||
// respect to each other.
|
||||
type lifecycleCmd struct {
|
||||
op lifecycleOp
|
||||
config *profilemanager.Config
|
||||
md metadata.MD
|
||||
mobileDep MobileDependency
|
||||
logPath string
|
||||
|
||||
// done is the caller's notification channel (nil for fire-and-forget). Its
|
||||
// meaning depends on op:
|
||||
// - opStart: receives the run's end result when the run terminates, or
|
||||
// errAlreadyRunning immediately if a run is already in flight.
|
||||
// - opStop: receives nil once the in-flight run has fully unwound.
|
||||
// - opWaitEstablished: receives the wait outcome (see waitEstablishedOrDone).
|
||||
done chan error
|
||||
|
||||
reply chan bool // opStatus only: receives whether a run is in flight
|
||||
waitCtx context.Context // opWaitEstablished only: the waiter's cancellation context
|
||||
}
|
||||
|
||||
// runState bundles the lifecycle channels of one in-flight run. It is created
// and owned by the supervisor loop and is not exposed as an API; besides the
// loop, only the goroutines spawned for opWaitEstablished read it.
//
// The run closes connEstablishedChan when the connection is established; the
// supervisor creates that channel itself and callers observe it through
// waitEstablishedOrDone. finishRun closes ended when the run terminates, which
// broadcasts the termination to any number of waiters. err holds the run's
// end result and may only be read after ended has been closed.
type runState struct {
	// connEstablishedChan is closed by the run once it is established.
	connEstablishedChan chan struct{}
	// ended is closed by finishRun when the run terminates.
	ended chan struct{}
	// err is the run's end result; valid only after ended is closed.
	err error
}
|
||||
|
||||
// runEndResult is the message the run goroutine sends to the supervisor when a
// run finishes, regardless of whether it ended by itself (an error or an
// external context cancellation) or as the consequence of a Stop.
type runEndResult struct {
	// err is the value returned by the run function.
	err error
}
|
||||
|
||||
// runFunc executes a single client run bound to the supervisor-owned context,
|
||||
// with the config supplied by the start request.
|
||||
type runFunc func(ctx context.Context, config *profilemanager.Config, mobileDep MobileDependency, connEstablishedChan chan struct{}, logPath string) error
|
||||
|
||||
// supervisor serializes start/stop of a single client run. Every request goes
|
||||
// through cmdCh and is handled one at a time by the loop goroutine, so two
|
||||
// lifecycle operations can never overlap and their order is preserved (FIFO).
|
||||
// The loop goroutine is the sole owner of curStart/runCancel, so that state
|
||||
// needs no locking. The loop exits when the parent context is cancelled.
|
||||
type supervisor struct {
|
||||
ctx context.Context
|
||||
run runFunc
|
||||
cmdCh chan lifecycleCmd
|
||||
runEnded chan runEndResult
|
||||
|
||||
// owned exclusively by the loop goroutine. curStart is the in-flight start
|
||||
// command (nil = idle); its done channel is notified when the run ends.
|
||||
// curRun holds that run's lifecycle channels; runCancel cancels it.
|
||||
curStart *lifecycleCmd
|
||||
curRun *runState
|
||||
runCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newSupervisor(ctx context.Context, run runFunc) *supervisor {
|
||||
s := &supervisor{
|
||||
ctx: ctx,
|
||||
run: run,
|
||||
cmdCh: make(chan lifecycleCmd, 16),
|
||||
runEnded: make(chan runEndResult, 1),
|
||||
}
|
||||
go s.loop()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *supervisor) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.shutdown()
|
||||
return
|
||||
case cmd := <-s.cmdCh:
|
||||
switch cmd.op {
|
||||
case opStart:
|
||||
s.handleStart(cmd)
|
||||
case opStop:
|
||||
s.handleStop(cmd)
|
||||
case opRestart:
|
||||
s.handleRestart(cmd)
|
||||
case opStatus:
|
||||
cmd.reply <- (s.isRunningInternal())
|
||||
case opWaitEstablished:
|
||||
s.handleWaitEstablished(cmd)
|
||||
}
|
||||
case res := <-s.runEnded:
|
||||
// Run ended on its own, without an explicit Stop.
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
notify(cmd.done, errAlreadyRunning)
|
||||
return
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(s.ctx)
|
||||
if cmd.md != nil {
|
||||
// Carry caller-supplied gRPC metadata (e.g. UI user-agent) into the run
|
||||
// context so the engine's management/signal calls forward it. The cancel
|
||||
// still drives runCtx (metadata wrapping preserves cancellation).
|
||||
runCtx = metadata.NewOutgoingContext(runCtx, cmd.md)
|
||||
}
|
||||
s.runCancel = cancel
|
||||
s.curStart = &cmd
|
||||
s.curRun = &runState{connEstablishedChan: make(chan struct{}), ended: make(chan struct{})}
|
||||
|
||||
go func(ctx context.Context, cfg *profilemanager.Config, m MobileDependency, established chan struct{}, lp string) {
|
||||
err := s.run(ctx, cfg, m, established, lp)
|
||||
s.runEnded <- runEndResult{err: err}
|
||||
}(runCtx, cmd.config, cmd.mobileDep, s.curRun.connEstablishedChan, cmd.logPath)
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStop(cmd lifecycleCmd) {
|
||||
if !s.isRunningInternal() {
|
||||
notify(cmd.done, nil)
|
||||
return
|
||||
}
|
||||
s.stopCurrentRun()
|
||||
notify(cmd.done, nil)
|
||||
}
|
||||
|
||||
// handleRestart tears down any in-flight run and starts a fresh one in a single
|
||||
// loop turn. No other command can interleave between the stop and the start
|
||||
// (the loop is single-threaded), so the swap is atomic without relying on any
|
||||
// daemon-side lock — that is what an explicit restart (e.g. MDM config change)
|
||||
// needs to avoid a window where the client is observably stopped.
|
||||
func (s *supervisor) handleRestart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
s.stopCurrentRun()
|
||||
}
|
||||
s.handleStart(cmd)
|
||||
}
|
||||
|
||||
// stopCurrentRun cancels the in-flight run and blocks the supervisor until it
|
||||
// has fully unwound, so the next action starts from a clean slate. The run
|
||||
// goroutine reports completion via runEnded. Caller must hold an in-flight run
|
||||
// (curStart != nil).
|
||||
func (s *supervisor) stopCurrentRun() {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
|
||||
// finishRun resets lifecycle state after a run terminates and hands the run
|
||||
// error back to whoever asked to be notified of the start.
|
||||
func (s *supervisor) finishRun(err error) {
|
||||
s.runCancel = nil
|
||||
if s.isRunningInternal() {
|
||||
// Publish the result to the broadcast channel before nil-ing curRun, so
|
||||
// any opWaitEstablished goroutines blocked on ended observe err.
|
||||
s.curRun.err = err
|
||||
close(s.curRun.ended)
|
||||
s.curRun = nil
|
||||
|
||||
notify(s.curStart.done, err)
|
||||
s.curStart = nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitEstablished answers an opWaitEstablished request. The select itself
|
||||
// runs in a spawned goroutine on the run's channels so it never blocks the loop;
|
||||
// the loop only snapshots the in-flight run's channels (which it owns) here.
|
||||
func (s *supervisor) handleWaitEstablished(cmd lifecycleCmd) {
|
||||
caller := cmd.done
|
||||
if !s.isRunningInternal() {
|
||||
notify(caller, errNoRunInFlight)
|
||||
return
|
||||
}
|
||||
rs := s.curRun
|
||||
established := rs.connEstablishedChan
|
||||
ctx := cmd.waitCtx
|
||||
go func() {
|
||||
select {
|
||||
case <-established:
|
||||
notify(caller, nil)
|
||||
case <-rs.ended:
|
||||
if rs.err != nil {
|
||||
notify(caller, rs.err)
|
||||
return
|
||||
}
|
||||
notify(caller, errStoppedBeforeEstablished)
|
||||
case <-ctx.Done():
|
||||
notify(caller, ctx.Err())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown tears down the in-flight run when the parent context is cancelled,
|
||||
// then fails any still-queued commands so their callers never hang.
|
||||
func (s *supervisor) shutdown() {
|
||||
if s.runCancel != nil {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case cmd := <-s.cmdCh:
|
||||
notify(cmd.done, s.ctx.Err())
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startAsync enqueues a start without blocking. If done is non-nil it receives
|
||||
// the run's end result (or errAlreadyRunning on rejection, or the context error
|
||||
// on shutdown).
|
||||
func (s *supervisor) startAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string, done chan error) {
|
||||
cmd := lifecycleCmd{op: opStart, config: config, md: md, mobileDep: mobileDep, logPath: logPath, done: done}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
notify(done, s.ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// restartAsync enqueues an atomic stop+start without blocking. The supervisor
|
||||
// tears down any in-flight run and starts a fresh one with the supplied config
|
||||
// in a single loop turn (see handleRestart). Fire-and-forget: the new run owns
|
||||
// its lifecycle channels, observed via waitEstablishedOrDone.
|
||||
func (s *supervisor) restartAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) {
|
||||
cmd := lifecycleCmd{op: opRestart, config: config, md: md, mobileDep: mobileDep, logPath: logPath}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// start enqueues a start and blocks until the run terminates, preserving the
|
||||
// blocking contract of the legacy Run entry points.
|
||||
func (s *supervisor) start(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) error {
|
||||
done := make(chan error, 1)
|
||||
s.startAsync(config, md, mobileDep, logPath, done)
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// isRunning asks the loop whether a run is in flight. The query is serialized
|
||||
// with start/stop, so during a stop it waits for the teardown to settle and
|
||||
// then reports the final state — never a transient "half-stopped".
|
||||
func (s *supervisor) isRunning() bool {
|
||||
reply := make(chan bool, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStatus, reply: reply}:
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
return r
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) isRunningInternal() bool {
|
||||
return s.curStart != nil
|
||||
}
|
||||
|
||||
// waitEstablishedOrDone blocks until the in-flight run becomes established
|
||||
// (returns nil) or ends before that (returns the run error, or
|
||||
// errStoppedBeforeEstablished on a clean stop), or ctx is cancelled. Returns
|
||||
// errNoRunInFlight if no run is in flight. The wait is performed by a goroutine
|
||||
// spawned inside the loop (see handleWaitEstablished); the run's channels never
|
||||
// leave the supervisor.
|
||||
func (s *supervisor) waitEstablishedOrDone(ctx context.Context) error {
|
||||
reply := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opWaitEstablished, waitCtx: ctx, done: reply}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-reply:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// stop enqueues a stop and blocks until the in-flight run is fully torn down.
|
||||
func (s *supervisor) stop() error {
|
||||
done := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStop, done: done}:
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// notify delivers err on a caller-supplied channel without ever blocking. The
// channel is expected to be buffered with capacity 1; if the send cannot
// proceed the value is dropped. A nil channel means the caller did not ask for
// a notification, and the call does nothing.
func notify(ch chan error, err error) {
	if ch != nil {
		select {
		case ch <- err:
		default:
		}
	}
}
|
||||
@@ -51,20 +51,13 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -83,10 +76,9 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -207,35 +207,3 @@ func FormatAnswers(answers []dns.RR) string {
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||
// advertise EDNS0.
|
||||
func StripOPT(msg *dns.Msg) {
|
||||
if len(msg.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := msg.Extra[:0]
|
||||
for _, rr := range msg.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
msg.Extra = out
|
||||
}
|
||||
|
||||
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||
// the message, if present.
|
||||
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||
opt := msg.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||
return ede, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -120,42 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
StripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestExtractEDE(t *testing.T) {
|
||||
t.Run("no edns", func(t *testing.T) {
|
||||
_, ok := ExtractEDE(&dns.Msg{})
|
||||
assert.False(t, ok, "message without OPT has no EDE")
|
||||
})
|
||||
|
||||
t.Run("edns without ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
rm.SetEdns0(4096, false)
|
||||
_, ok := ExtractEDE(rm)
|
||||
assert.False(t, ok, "OPT without EDE option returns false")
|
||||
})
|
||||
|
||||
t.Run("with ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||
rm.Extra = append(rm.Extra, opt)
|
||||
|
||||
ede, ok := ExtractEDE(rm)
|
||||
assert.True(t, ok, "EDE option should be found")
|
||||
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -39,15 +38,11 @@ const (
|
||||
// defaultWarningDelayBase is the starting grace window before a
|
||||
// "Nameserver group unreachable" event fires for a group that's
|
||||
// never been healthy and only has overlay upstreams with no
|
||||
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||
// see warningDelay.
|
||||
defaultWarningDelayBase = 60 * time.Second
|
||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
||||
defaultWarningDelayBase = 30 * time.Second
|
||||
// warningDelayBonusCap caps the route-count bonus added to the
|
||||
// base grace window. See warningDelay.
|
||||
// base grace window. See warningDelayFor.
|
||||
warningDelayBonusCap = 30 * time.Second
|
||||
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||
)
|
||||
|
||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||
@@ -140,7 +135,7 @@ type DefaultServer struct {
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxHandlers []handlerWrapper
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@@ -204,6 +199,8 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
// registeredHandlerMap indexes handler registrations by handler ID.
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||
type DefaultServerConfig struct {
|
||||
WgInterface WGIface
|
||||
@@ -292,6 +289,7 @@ func newDefaultServer(
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -300,7 +298,7 @@ func newDefaultServer(
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
warningDelayBase: warningDelayBaseFromEnv(),
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
@@ -330,7 +328,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
type routeSettable interface {
|
||||
setSelectedRoutes(func() route.HAMap)
|
||||
}
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
if h, ok := entry.handler.(routeSettable); ok {
|
||||
h.setSelectedRoutes(selected)
|
||||
}
|
||||
@@ -980,23 +978,19 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxHandlers {
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
// The local resolver is a persistent singleton shared by every custom
|
||||
// zone and reused across config updates. Its chain registrations are
|
||||
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||
// context (breaking external CNAME-target resolution) and clears its
|
||||
// records, so it must not be torn down here.
|
||||
if existing.handler != s.localResolver {
|
||||
existing.handler.Stop()
|
||||
}
|
||||
existing.handler.Stop()
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
s.dnsMuxHandlers = muxUpdates
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||
@@ -1160,26 +1154,6 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
||||
return false
|
||||
}
|
||||
|
||||
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||
// non-positive values fall back to defaultWarningDelayBase.
|
||||
func warningDelayBaseFromEnv() time.Duration {
|
||||
val := os.Getenv(envWarningDelay)
|
||||
if val == "" {
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
if d <= 0 {
|
||||
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// warningDelay returns the grace window for the given selected-route
|
||||
// count. Scales gently: +1s per 100 routes, capped by
|
||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||
@@ -1230,7 +1204,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
||||
// in more than one handler.
|
||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||
if !ok {
|
||||
continue
|
||||
|
||||
@@ -104,6 +104,19 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -119,20 +132,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -154,17 +169,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
@@ -173,10 +191,10 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
@@ -197,13 +215,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
@@ -212,7 +232,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
@@ -220,7 +240,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -242,7 +262,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -264,7 +284,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -281,41 +301,42 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
@@ -372,7 +393,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
@@ -384,20 +405,14 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,8 +512,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
@@ -1014,15 +1029,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := []handlerWrapper{
|
||||
{
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
@@ -1031,15 +1046,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := []handlerWrapper{
|
||||
{
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
"upstream-root2": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
@@ -1048,22 +1063,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := []handlerWrapper{
|
||||
{
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
},
|
||||
{
|
||||
"upstream-other": {
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
@@ -1074,7 +1089,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers []handlerWrapper
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
@@ -1358,38 +1373,32 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxHandlers: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
var found *handlerWrapper
|
||||
for i := range server.dnsMuxHandlers {
|
||||
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||
found = &server.dnsMuxHandlers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||
if found != nil {
|
||||
assert.Equal(t, expectedDomain, found.domain,
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for _, entry := range server.dnsMuxHandlers {
|
||||
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@@ -1404,7 +1413,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxHandlers {
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
@@ -1413,108 +1422,12 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxHandlers")
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||
// for the given fqdn pattern at the given priority.
|
||||
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||
for _, h := range s.handlerChain.handlers {
|
||||
if h.OrigPattern == pattern && h.Priority == priority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||
// tracks each (handler, domain) registration independently when one handler
|
||||
// serves multiple zones. Every custom zone is served by the same handler
|
||||
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||
// zones onto one entry, leaving removed zones in the chain to answer
|
||||
// authoritatively with no records.
|
||||
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||
// One handler serves every custom zone, mirroring s.localResolver.
|
||||
shared := &mockHandler{Id: "local-resolver"}
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Two custom zones under the same handler. The surviving zone is registered
|
||||
// last, mirroring the management emission order.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test should be registered after the first update")
|
||||
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should be registered after the first update")
|
||||
|
||||
// Remove one zone, keep the other.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should remain after removing userzone.test")
|
||||
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||
// does not tear down the shared local resolver during reconfiguration. The
|
||||
// resolver is a process-lifetime singleton reused across config updates;
|
||||
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||
// resolution) and clears its records. updateMux must deregister its chain
|
||||
// entries without stopping it. Records surviving a teardown update is the
|
||||
// observable proxy: Stop() would have cleared them.
|
||||
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||
resolver := local.NewResolver()
|
||||
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||
Name: "peer.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}))
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
localResolver: resolver,
|
||||
}
|
||||
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
// Remove the zone. The resolver must survive so its records and lookup
|
||||
// context stay intact for the next registration.
|
||||
server.updateMux(nil)
|
||||
|
||||
var response *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
response = m
|
||||
return nil
|
||||
},
|
||||
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||
|
||||
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2136,6 +2049,7 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
@@ -2293,7 +2207,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
@@ -2369,11 +2283,12 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
@@ -2480,6 +2395,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
@@ -2491,7 +2407,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2528,6 +2444,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
@@ -2542,7 +2459,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2567,32 +2484,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set bool
|
||||
val string
|
||||
want time.Duration
|
||||
}{
|
||||
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envWarningDelay, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(envWarningDelay)
|
||||
}
|
||||
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
@@ -2704,6 +2595,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
@@ -2721,7 +2613,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2748,6 +2640,7 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -443,32 +443,29 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||
u.markUpstreamOk(upstream)
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||
// refused zones, transient recursion errors), not reachability
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(rm)
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(rm)
|
||||
stripOPT(rm)
|
||||
}
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
@@ -523,6 +520,22 @@ func upstreamUDPSize() uint16 {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
|
||||
@@ -517,78 +517,6 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||
// those are per-question outcomes from a reachable server and must not mark
|
||||
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respA mockUpstreamResponse
|
||||
respB mockUpstreamResponse
|
||||
wantHealthy bool
|
||||
}{
|
||||
{
|
||||
name: "both SERVFAIL are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "both REFUSED are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "timeout marks unhealthy",
|
||||
respA: mockUpstreamResponse{err: timeoutErr},
|
||||
respB: mockUpstreamResponse{err: timeoutErr},
|
||||
wantHealthy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): tc.respA,
|
||||
b.String(): tc.respB,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a, b})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||
if tc.wantHealthy {
|
||||
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||
} else {
|
||||
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -985,6 +913,19 @@ func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
@@ -26,15 +26,6 @@ import (
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||
// client can see the reason without inspecting this peer's logs. They live in
|
||||
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||
const (
|
||||
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||
edeNetbirdUpstreamFailure uint16 = 49153
|
||||
)
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
@@ -229,7 +220,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -342,7 +333,6 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
@@ -384,10 +374,6 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
@@ -428,33 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return edeNetbirdUpstreamTimeout
|
||||
}
|
||||
return edeNetbirdUpstreamFailure
|
||||
}
|
||||
|
||||
// edeText returns the EDE extra-text naming the class of upstream failure:
// a timeout message when dnsErr is non-nil and flagged as a timeout, and a
// generic failure message otherwise (including a nil dnsErr). The upstream
// server address is intentionally left out, because it may be an internal
// resolver and the text is visible to any client allowed to use the route;
// the full detail remains in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
	if dnsErr == nil || !dnsErr.IsTimeout {
		return "netbird forwarder: upstream failure"
	}
	return "netbird forwarder: upstream timeout"
}
|
||||
|
||||
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||
// creating the OPT pseudo-record if the response does not already carry one.
|
||||
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||
opt := resp.IsEdns0()
|
||||
if opt == nil {
|
||||
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||
opt = resp.IsEdns0()
|
||||
}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -618,85 +617,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lookupErr error
|
||||
reqEdns bool
|
||||
wantEDE bool
|
||||
wantCode uint16
|
||||
wantTextHas string
|
||||
}{
|
||||
{
|
||||
name: "timeout with edns0",
|
||||
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamTimeout,
|
||||
wantTextHas: "netbird forwarder: upstream timeout",
|
||||
},
|
||||
{
|
||||
name: "server failure with edns0",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamFailure,
|
||||
wantTextHas: "netbird forwarder: upstream failure",
|
||||
},
|
||||
{
|
||||
name: "no edns0 in request omits ede",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: false,
|
||||
wantEDE: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
if tt.reqEdns {
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
mockResolver.AssertExpectations(t)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||
|
||||
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||
if !tt.wantEDE {
|
||||
assert.False(t, ok, "response must not carry EDE")
|
||||
return
|
||||
}
|
||||
require.True(t, ok, "response must carry EDE")
|
||||
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -86,8 +88,6 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -201,8 +201,6 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
started bool
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
@@ -283,15 +281,9 @@ func NewEngine(
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
// The engine is single-use: a fresh instance is built per connection
|
||||
// cycle (see Client.run), so the run context is created once here rather
|
||||
// than in Start.
|
||||
ctx, cancel := context.WithCancel(clientCtx)
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
@@ -324,34 +316,8 @@ func (e *Engine) Stop() error {
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.cancel()
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
e.stopLocked()
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked tears down everything Start may have brought up, in the order
|
||||
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||
func (e *Engine) stopLocked() {
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
@@ -402,6 +368,10 @@ func (e *Engine) stopLocked() {
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
@@ -420,6 +390,21 @@ func (e *Engine) stopLocked() {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
@@ -457,38 +442,18 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// The engine is single-use. Reject a duplicate start and a start on an
|
||||
// already-stopped engine (run context cancelled).
|
||||
if e.started {
|
||||
return ErrEngineAlreadyStarted
|
||||
}
|
||||
|
||||
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||
}
|
||||
|
||||
e.started = true
|
||||
|
||||
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||
// run context first so goroutines started before the failure (connMgr,
|
||||
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||
// not just what close() covers.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
e.cancel()
|
||||
e.stopLocked()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
@@ -522,11 +487,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -561,6 +528,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -569,6 +537,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -580,6 +549,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -604,7 +574,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
if err := e.dnsServer.Initialize(); err != nil {
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -616,9 +588,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
if err = e.receiveSignalEvents(); err != nil {
|
||||
return err
|
||||
}
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
@@ -670,6 +640,7 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
@@ -1066,7 +1037,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
@@ -1097,20 +1068,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||
// can be excluded from the reported network addresses; the interface coming and
|
||||
// going otherwise churns the peer meta on the management server.
|
||||
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||
var ips []netip.Addr
|
||||
if e.config.WgAddr.IP.IsValid() {
|
||||
ips = append(ips, e.config.WgAddr.IP)
|
||||
}
|
||||
if e.config.WgAddr.HasIPv6() {
|
||||
ips = append(ips, e.config.WgAddr.IPv6)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -1170,6 +1127,20 @@ func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
}
|
||||
|
||||
// wrapDisconnectError classifies a receive-loop failure before the run is torn
|
||||
// down. An auth rejection (PermissionDenied/Unauthenticated) means the session
|
||||
// needs re-login and retrying is futile, so mark it terminal (NeedsLogin) — run()
|
||||
// then exits on its own instead of spinning the backoff. Any other failure is a
|
||||
// recoverable connection reset that the backoff should retry.
|
||||
func (e *Engine) wrapDisconnectError(err error) {
|
||||
state := CtxGetState(e.ctx)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied || s.Code() == codes.Unauthenticated) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
return
|
||||
}
|
||||
_ = state.Wrap(ErrResetConnection)
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
@@ -1196,9 +1167,9 @@ func (e *Engine) receiveJobEvents() {
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1254,7 +1225,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
@@ -1280,9 +1251,9 @@ func (e *Engine) receiveManagementEvents() {
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1743,7 +1714,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() error {
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
@@ -1806,20 +1777,15 @@ func (e *Engine) receiveSignalEvents() error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// happens if signal is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if signal is unavailable for a long time, or rejects us
|
||||
// (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// TODO: consider removing this blocking wait; there is no apparent benefit in blocking the Start operation
|
||||
e.signal.WaitStreamConnected(e.ctx)
|
||||
if err := e.ctx.Err(); err != nil {
|
||||
return fmt.Errorf("wait for signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
e.signal.WaitStreamConnected()
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
|
||||
@@ -119,6 +119,10 @@ func (d *BindListener) ReadPackets() {
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.lazyConn.Close()
|
||||
d.bind.RemoveEndpoint(d.fakeIP)
|
||||
d.done.Done()
|
||||
|
||||
@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
|
||||
}
|
||||
|
||||
offer := h.buildOfferAnswer()
|
||||
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
|
||||
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||
}
|
||||
|
||||
func (h *Handshaker) sendAnswer() error {
|
||||
answer := h.buildOfferAnswer()
|
||||
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
|
||||
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalAnswer(answer, h.config.Key)
|
||||
}
|
||||
|
||||
@@ -192,7 +192,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
muxRelays sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
@@ -245,8 +244,8 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
}
|
||||
|
||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayMgr = manager
|
||||
}
|
||||
|
||||
@@ -907,8 +906,8 @@ func (d *Status) MarkSignalConnected() {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
@@ -1019,26 +1018,24 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.muxRelays.RLock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.relayMgr == nil {
|
||||
defer d.muxRelays.RUnlock()
|
||||
return slices.Clone(d.relayStates)
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
relayMgr := d.relayMgr
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
d.muxRelays.RUnlock()
|
||||
|
||||
states := relayMgr.RelayStates()
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range relayMgr.ServerURLs() {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
|
||||
@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
|
||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
||||
if *input.ServerSSHAllowed {
|
||||
log.Infof("enabling SSH server")
|
||||
} else {
|
||||
|
||||
@@ -242,35 +242,6 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
|
||||
// Configs written before ServerSSHAllowed was introduced lack the field and
|
||||
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
|
||||
// apply the value instead of panicking on a nil pointer dereference.
|
||||
tests := []struct {
|
||||
name string
|
||||
input *bool
|
||||
want bool
|
||||
}{
|
||||
{"enable", util.True(), true},
|
||||
{"disable", util.False(), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
|
||||
|
||||
config, err := UpdateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
ServerSSHAllowed: tt.input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
|
||||
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
|
||||
@@ -251,14 +251,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||
// the original client did not request EDNS0.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
if !hadEdns {
|
||||
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
@@ -268,13 +260,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||
}
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(reply)
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
|
||||
@@ -171,13 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
return connectClient.RunOniOS(cfg, fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -36,7 +36,6 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
@@ -52,19 +51,8 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||
// process (decoupled from the network extension), so without this the server
|
||||
// lingers after the user dismisses the browser and the next connect stalls
|
||||
// trying to bind the same port.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
ctx: context.Background(),
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
@@ -72,24 +60,12 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||
// safe to call when no login is running.
|
||||
func (a *Auth) Stop() {
|
||||
if a.cancel != nil {
|
||||
a.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported tests the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info, then it saves the configuration and returns true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
|
||||
@@ -344,9 +344,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
||||
|
||||
@@ -5,7 +5,6 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/pprof"
|
||||
|
||||
@@ -28,11 +27,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}
|
||||
|
||||
var clientMetrics debug.MetricsExporter
|
||||
if s.connectClient != nil {
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
}
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +45,10 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
defer s.cleanupBundleCapture()
|
||||
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +112,7 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
}
|
||||
s.connectClient.SetLogLevel(level)
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
@@ -134,20 +126,13 @@ func (s *Server) SetSyncResponsePersistence(_ context.Context, req *proto.SetSyn
|
||||
|
||||
enabled := req.GetEnabled()
|
||||
s.persistSyncResponse = enabled
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
}
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
|
||||
return &proto.SetSyncResponsePersistenceResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
cClient := s.connectClient
|
||||
if cClient == nil {
|
||||
return nil, errors.New("connect client is not initialized")
|
||||
}
|
||||
|
||||
return cClient.GetLatestSyncResponse()
|
||||
return s.connectClient.GetLatestSyncResponse()
|
||||
}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon.
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -39,12 +38,11 @@ type conflictCheck struct {
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// 1. Stop the in-flight run via the supervisor (blocks until fully torn down).
|
||||
// 2. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// 3. Start a fresh run with the new config.
|
||||
// 4. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
@@ -52,39 +50,24 @@ type conflictCheck struct {
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
// Hold s.mutex for the entire restart sequence (stop + re-start). Any
|
||||
// concurrent Up/Down/Status arriving while MDM is restarting blocks on the
|
||||
// Lock until we are done — they then observe the post-restart state coherently.
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
// No run in flight, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel daemon-side login/status activities tied to the old run; the run
|
||||
// itself is torn down atomically by the supervisor inside Restart (see
|
||||
// restartEngineForMDMLocked), which stops and re-starts in one operation.
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
@@ -131,14 +114,13 @@ func (s *Server) publishConfigChangedEvent(source string) {
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
// (re-running applyMDMPolicy via Config.apply) and starts a fresh run.
|
||||
// Mirrors the tail of Server.Start so a runtime MDM change behaves
|
||||
// identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
// for the entire restart sequence so concurrent Up/Down/Status RPCs
|
||||
// observe a coherent post-restart state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
@@ -154,13 +136,13 @@ func (s *Server) restartEngineForMDMLocked() error {
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
log.Info("MDM restart: atomically restarting the run with re-resolved config")
|
||||
// MDM restart has no incoming RPC metadata; fire and forget. Restart is a
|
||||
// single supervisor op (atomic stop+start), so there is no observable
|
||||
// "stopped" window between tearing down the old run and starting the new.
|
||||
s.connectClient.Restart(config, nil)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -34,10 +34,6 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -147,10 +143,6 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -199,10 +191,6 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
|
||||
@@ -8,12 +8,10 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -39,15 +37,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
probeThreshold = time.Second * 5
|
||||
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
|
||||
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
|
||||
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
|
||||
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
|
||||
defaultInitialRetryTime = 30 * time.Minute
|
||||
defaultMaxRetryInterval = 60 * time.Minute
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
probeThreshold = time.Second * 5
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
@@ -72,15 +62,8 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
// Run state (in-flight? established/done channels?) is owned entirely by the
|
||||
// supervisor inside connectClient — the daemon keeps no per-run fields.
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -136,6 +119,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
// The ConnectClient is daemon-lifetime: build it exactly once, here. Its
|
||||
// supervisor lives as long as the daemon; Up/Down/MDM and reconnects all
|
||||
// drive this same instance. updateManager isn't ready yet (created in
|
||||
// Start) and is injected there via SetUpdateManager.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
@@ -147,7 +137,7 @@ func (s *Server) Start() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.clientRunning {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,6 +155,7 @@ func (s *Server) Start() error {
|
||||
stateMgr := statemanager.New(s.profileManager.GetStatePath())
|
||||
s.updateManager = updater.NewManager(s.statusRecorder, stateMgr)
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
s.connectClient.SetUpdateManager(s.updateManager)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
@@ -190,7 +181,9 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
// actCancel cancels in-flight foreground operations (login/status); the run
|
||||
// itself is owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
// copy old default config
|
||||
@@ -232,99 +225,14 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
// Boot autoconnect: no incoming RPC metadata. The supervisor runs the
|
||||
// client and reconnects internally; we just fire and forget (the run owns
|
||||
// its established/done channels).
|
||||
s.connectClient.RunAsync(config, nil)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
|
||||
log.Debugf("run client connection exited with error: %v", err)
|
||||
}
|
||||
log.Tracef("client connection exited")
|
||||
return
|
||||
}
|
||||
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
go func() {
|
||||
t := time.NewTicker(24 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
mgmtState := statusRecorder.GetManagementState()
|
||||
signalState := statusRecorder.GetSignalState()
|
||||
if mgmtState.Connected && signalState.Connected {
|
||||
log.Tracef("resetting status")
|
||||
backOff.Reset()
|
||||
} else {
|
||||
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
runOperation := func() error {
|
||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
||||
if err != nil {
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("client connection exited gracefully, do not need to retry")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
@@ -720,13 +628,22 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
|
||||
// The client (and its supervisor) is built once in New(), so a nil here
|
||||
// never happens in production — Up is only reachable after New() has run and
|
||||
// the gRPC server is serving. The real case this guards is the daemon
|
||||
// SHUTTING DOWN: rootCtx is cancelled, the supervisor is no longer accepting
|
||||
// commands, so ServiceRunning() is false even though the client exists. Bail
|
||||
// loud instead of enqueuing a run that will never start. (nil only happens in
|
||||
// tests that build a Server without New(); ServiceRunning is nil-safe.)
|
||||
if !s.connectClient.ServiceRunning() {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("service is not running, start the netbird service for 'up' to take effect")
|
||||
}
|
||||
|
||||
// If a connection run is already in flight, the existing engine is on the
|
||||
// job — just wait for it. Otherwise fall through to start a fresh run.
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -764,14 +681,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
md, ok := metadata.FromIncomingContext(callerCtx)
|
||||
if ok {
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
// actCancel cancels in-flight foreground ops (login/status); the run is
|
||||
// owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
// Forward the caller's gRPC metadata (e.g. UI user-agent) into the run.
|
||||
md, _ := metadata.FromIncomingContext(callerCtx)
|
||||
|
||||
if s.config == nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
@@ -812,35 +729,26 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.connectClient.RunAsync(s.config, md)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
// todo: handle potential race conditions
|
||||
// waitForUp blocks until the in-flight run becomes established (success) or ends
|
||||
// before that (failure). The wait is owned by the supervisor (via the client) —
|
||||
// the daemon holds no per-run state here.
|
||||
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return nil, fmt.Errorf("client gave up to connect")
|
||||
case <-s.clientRunningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
if err := s.connectClient.WaitEstablishedOrDone(timeoutCtx); err != nil {
|
||||
log.Debugf("waiting for the connection to be established failed: %v", err)
|
||||
return nil, fmt.Errorf("connection not established: %w", err)
|
||||
}
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
@@ -935,11 +843,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
// Down engine work in the daemon.
|
||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
giveUpChan := s.clientGiveUpChan
|
||||
|
||||
// cleanupConnection stops the run through the supervisor, which blocks until
|
||||
// the run has fully unwound — no separate goroutine-quiescence wait needed.
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
s.mutex.Unlock()
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
@@ -948,20 +856,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||
if giveUpChan != nil {
|
||||
select {
|
||||
case <-giveUpChan:
|
||||
log.Debugf("client goroutine finished successfully")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -972,38 +866,19 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
// to skip the engine shutdown entirely.
|
||||
var engine *internal.Engine
|
||||
if s.connectClient != nil {
|
||||
engine = s.connectClient.Engine()
|
||||
// Tear the client down through the lifecycle supervisor BEFORE cancelling
|
||||
// the retry context. Stop serializes on the supervisor queue and blocks
|
||||
// until the in-flight run has fully unwound (a clean, synchronous teardown).
|
||||
// It must run before actCancel: cancelling the context first would make
|
||||
// Stop observe a dead context and return early without waiting.
|
||||
if err := s.connectClient.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop the retry goroutine so it does not start a fresh run. The client
|
||||
// itself is daemon-lifetime and intentionally kept (a later Up reuses it).
|
||||
s.actCancel()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||
// making the run loop the sole owner of engine shutdown.
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.connectClient = nil
|
||||
s.isSessionActive.Store(false)
|
||||
|
||||
log.Infof("service is down")
|
||||
@@ -1138,7 +1013,7 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
|
||||
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient.ConnectionRunning() {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
@@ -1184,48 +1059,13 @@ func (s *Server) Status(
|
||||
ctx context.Context,
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-s.clientRunningChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// A run that hits a terminal auth failure now exits on its own (engine marks
|
||||
// NeedsLogin), so we no longer poll-and-cancel: we wait for the in-flight run
|
||||
// to become established or to end. With no run in flight this returns
|
||||
// immediately (errNoRunInFlight); either way we then report the status below.
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady {
|
||||
if err := s.connectClient.WaitEstablishedOrDone(ctx); err != nil && ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1263,10 +1103,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
@@ -1304,10 +1140,6 @@ func (s *Server) GetPeerSSHHostKey(
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
@@ -1474,17 +1306,13 @@ func (s *Server) WaitJWTToken(
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||
s.mutex.Lock()
|
||||
if !s.clientRunning {
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
s.mutex.Unlock()
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
|
||||
}
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
@@ -1538,10 +1366,6 @@ func isUnixRunningDesktop() bool {
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
@@ -1820,22 +1644,6 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
|
||||
log.Tracef("running client connection")
|
||||
client := internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
client.SetUpdateManager(s.updateManager)
|
||||
client.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
s.mutex.Lock()
|
||||
s.connectClient = client
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := client.Run(runningChan, s.logFile); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
@@ -1897,45 +1705,6 @@ func (s *Server) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
|
||||
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
|
||||
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
|
||||
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
|
||||
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
|
||||
multiplier := defaultRetryMultiplier
|
||||
|
||||
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
|
||||
// parse the multiplier from the environment variable string value to float64
|
||||
value, err := strconv.ParseFloat(envValue, 64)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
|
||||
} else {
|
||||
multiplier = value
|
||||
}
|
||||
}
|
||||
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: initialInterval,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: multiplier,
|
||||
MaxInterval: maxInterval,
|
||||
MaxElapsedTime: maxElapsedTime, // 14 days
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// parseEnvDuration parses the environment variable and returns the duration
|
||||
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
|
||||
if envValue := os.Getenv(envVar); envValue != "" {
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
return duration
|
||||
}
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
|
||||
}
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -15,14 +15,19 @@ import (
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
return &Server{
|
||||
rootCtx: context.Background(),
|
||||
ctx := context.Background()
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
// Honor the production invariant: the daemon-lifetime client always exists
|
||||
// (built in New). Server methods rely on s.connectClient being non-nil.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
return s
|
||||
}
|
||||
|
||||
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
|
||||
return internal.NewConnectClient(ctx, nil, nil)
|
||||
return internal.NewConnectClient(ctx, nil)
|
||||
}
|
||||
|
||||
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
|
||||
@@ -87,41 +92,36 @@ func TestConcurrentConnectClientAccess(t *testing.T) {
|
||||
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
|
||||
}
|
||||
|
||||
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
|
||||
// properly nils out connectClient.
|
||||
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
// TestCleanupConnection_KeepsClientStopsRunning validates that cleanupConnection
|
||||
// clears the daemon "up" intent but KEEPS the daemon-lifetime ConnectClient
|
||||
// (it is reused across Up/Down; only the run is stopped).
|
||||
func TestCleanupConnection_KeepsClientStopsRunning(t *testing.T) {
|
||||
s := newTestServer()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
s.clientRunning = true
|
||||
|
||||
err := s.cleanupConnection()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
assert.NotNil(t, s.connectClient, "connectClient is daemon-lifetime and must persist after cleanup")
|
||||
assert.False(t, s.connectClient.ConnectionRunning(), "no run should be in flight after cleanup")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestCleanState_NilConnectClient(t *testing.T) {
|
||||
// TestCleanState_NotConnected validates that CleanState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestCleanState_NotConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil // will cause error if it tries to proceed past the nil check
|
||||
s.profileManager = nil // will cause error if it tries to proceed
|
||||
|
||||
// Should not panic — the nil check should prevent calling Status() on nil
|
||||
assert.NotPanics(t, func() {
|
||||
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
|
||||
})
|
||||
}
|
||||
|
||||
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestDeleteState_NilConnectClient(t *testing.T) {
|
||||
// TestDeleteState_NotConnected validates that DeleteState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestDeleteState_NotConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
@@ -129,60 +129,6 @@ func TestDeleteState_NilConnectClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDownThenUp_StaleRunningChan documents the known state issue where
|
||||
// clientRunningChan from a previous connection is already closed, causing
|
||||
// waitForUp() to return immediately on reconnect.
|
||||
func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Simulate state after a successful connection
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
close(s.clientRunningChan) // closed when engine started
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
waitDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.waitForUp(ctx)
|
||||
waitDone <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-waitDone:
|
||||
assert.NoError(t, err, "waitForUp returns success on stale channel")
|
||||
// But connectClient is still nil — this is the stale state issue
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
|
||||
s.mutex.Unlock()
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
|
||||
// ConnectClient has nil Engine (before Run is called).
|
||||
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -61,65 +60,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -38,7 +37,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
|
||||
|
||||
// CleanState handles cleaning of states (performing cleanup operations)
|
||||
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
@@ -81,7 +80,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
|
||||
|
||||
// DeleteState handles deletion of states without cleanup
|
||||
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
|
||||
@@ -62,10 +62,6 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
}
|
||||
|
||||
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, nil, fmt.Errorf("engine not initialized")
|
||||
|
||||
@@ -3,7 +3,6 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -122,23 +121,6 @@ func (i *Info) SetFlags(
|
||||
}
|
||||
}
|
||||
|
||||
// removeAddresses drops network addresses whose IP matches any of the given
|
||||
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
|
||||
// address, which otherwise churns the meta as the interface comes and goes.
|
||||
func (i *Info) removeAddresses(ips ...netip.Addr) {
|
||||
if len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
filtered := i.NetworkAddresses[:0]
|
||||
for _, addr := range i.NetworkAddresses {
|
||||
if slices.Contains(ips, addr.NetIP.Addr()) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, addr)
|
||||
}
|
||||
i.NetworkAddresses = filtered
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
@@ -165,9 +147,7 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
// excludeIPs are dropped from the reported network addresses (e.g. our own
|
||||
// WireGuard overlay address, which otherwise churns the peer meta).
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||
processCheckPaths := make([]string, 0)
|
||||
for _, check := range checks {
|
||||
@@ -182,7 +162,6 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
|
||||
info := GetInfo(ctx)
|
||||
info.Files = files
|
||||
info.removeAddresses(excludeIPs...)
|
||||
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
|
||||
@@ -2,7 +2,6 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -44,42 +43,3 @@ func Test_NetAddresses(t *testing.T) {
|
||||
t.Errorf("no network addresses found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfo_RemoveAddresses(t *testing.T) {
|
||||
addr := func(cidr string) NetworkAddress {
|
||||
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
NetworkAddresses: []NetworkAddress{
|
||||
addr("192.168.1.7/24"),
|
||||
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
|
||||
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
|
||||
addr("fd00:1234::1/64"), // overlay v6
|
||||
},
|
||||
}
|
||||
|
||||
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
|
||||
info.removeAddresses(
|
||||
netip.MustParseAddr("100.76.70.97"),
|
||||
netip.MustParseAddr("fd00:1234::1"),
|
||||
)
|
||||
|
||||
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
|
||||
if len(info.NetworkAddresses) != len(want) {
|
||||
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
|
||||
}
|
||||
for i, w := range want {
|
||||
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
|
||||
t.Errorf("address[%d] = %s, want %s", i, got, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
|
||||
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
|
||||
info.removeAddresses()
|
||||
if len(info.NetworkAddresses) != 1 {
|
||||
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,7 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
// Skip link-local and multicast: they carry no routable peer info and the
|
||||
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
|
||||
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||
if ipNet.IP.IsLoopback() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mustIPNet parses cidr and returns a *net.IPNet whose IP field holds the
// host address written in cidr. net.ParseCIDR stores the masked network
// address in the returned IPNet, so the host address it returns separately is
// written back before returning. A parse error fails the test immediately.
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
	t.Helper()

	ip, ipNet, err := net.ParseCIDR(cidr)
	if err != nil {
		t.Fatalf("parse %q: %v", cidr, err)
	}

	// Keep the host bits: callers want the interface address, not the network.
	ipNet.IP = ip
	return ipNet
}
|
||||
|
||||
func TestToNetworkAddress_Filtering(t *testing.T) {
|
||||
const mac = "c8:4b:d6:b6:04:ac"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cidr string
|
||||
want bool
|
||||
}{
|
||||
{"ipv4 global", "10.65.16.181/23", true},
|
||||
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
|
||||
{"ipv4 loopback", "127.0.0.1/8", false},
|
||||
{"ipv6 loopback", "::1/128", false},
|
||||
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
|
||||
{"ipv4 link-local", "169.254.1.2/16", false},
|
||||
{"ipv6 multicast", "ff02::1/128", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
|
||||
if got != tt.want {
|
||||
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -418,14 +418,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
// Suppress the on-boot Quick Actions popup when the daemon
|
||||
// reports DisableAutoConnect=true — that flag carries both the
|
||||
// user's "Connect on Startup = off" preference AND any MDM-
|
||||
// enforced override (applyMDMPolicy writes the policy value
|
||||
// into the same Config field). See netbirdio/netbird#5744.
|
||||
if !s.disableAutoConnectFromDaemon() {
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
s.showQuickActionsUI()
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
}
|
||||
@@ -1345,40 +1338,6 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||
// the active profile has DisableAutoConnect=true. Used by the
|
||||
// --quick-actions startup path to suppress the on-boot popup when the
|
||||
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||
// converge on the same Config field because applyMDMPolicy writes the
|
||||
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||
// daemon hiccup does not silently swallow the popup.
|
||||
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||
return false
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||
return false
|
||||
}
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||
return false
|
||||
}
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||
return false
|
||||
}
|
||||
return srvCfg.GetDisableAutoConnect()
|
||||
}
|
||||
|
||||
// getSrvConfig from the service to show it in the settings window.
|
||||
func (s *serviceClient) getSrvConfig() {
|
||||
s.managementURL = profilemanager.DefaultManagementURL
|
||||
|
||||
56
docker/build-env/README.md
Normal file
56
docker/build-env/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Build environments
|
||||
|
||||
Dockerfiles that pin the same toolchain CI uses, so a developer can
|
||||
reproduce a CI build locally without installing platform SDKs on their
|
||||
workstation. The version pins in each `Dockerfile` must stay in lockstep
|
||||
with `.github/workflows/`.
|
||||
|
||||
## `android/`
|
||||
|
||||
Mirrors `.github/workflows/mobile-build-validation.yml` (`android_build`
job). Carries Go 1.25.5, OpenJDK 11 (Ubuntu's `openjdk-11-jdk-headless`;
CI uses the Adopt distribution), Android cmdline-tools 8512546,
NDK 23.1.7779620 and gomobile pinned at the CI commit. Use it to
produce `netbird.aar` from `./client/android`:
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-android docker/build-env/android
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
gomobile bind \
|
||||
-o netbird.aar \
|
||||
-javapkg=io.netbird.gomobile \
|
||||
-ldflags="-checklinkname=0 \
|
||||
-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
-X github.com/netbirdio/netbird/version.version=local" \
|
||||
./client/android
|
||||
```
|
||||
|
||||
To build the full Android APK, bind-mount the `android-client` repo as
|
||||
well and run its own `./gradlew assembleDebug` from inside the
|
||||
container (the gradle wrapper ships with `android-client`).
|
||||
|
||||
## `windows-cross/`
|
||||
|
||||
Cross-compiles Windows binaries from Linux using `mingw-w64`. Lets you
verify that `GOOS=windows go build ./...` compiles cleanly without
needing a Windows VM. It cannot run Windows tests: the `golang-test-windows`
CI job executes on a native `windows-latest` runner with wintun.dll
and PsExec, neither of which is available inside a Linux container.
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
```
|
||||
|
||||
## What is NOT here
|
||||
|
||||
- **iOS / macOS**: cannot legally run macOS in Docker (Apple EULA),
|
||||
and Xcode is not redistributable. The `ios_build` CI job uses a
|
||||
`macos-latest` GitHub runner; locally you need a real Mac.
|
||||
|
||||
- **Native Windows tests**: see note above. The Linux+mingw image
|
||||
builds, it does not execute Windows-host code paths
|
||||
(registry, wintun, services, PsExec workflows).
|
||||
|
||||
When CI version pins change, update the corresponding `ARG` lines in
the Dockerfiles and the version numbers quoted in this README and in
the Dockerfile header comments.
|
||||
86
docker/build-env/android/Dockerfile
Normal file
86
docker/build-env/android/Dockerfile
Normal file
@@ -0,0 +1,86 @@
|
||||
# Android build environment.
|
||||
#
|
||||
# Mirrors the toolchain pinned by .github/workflows/mobile-build-validation.yml
|
||||
# so a `gomobile bind` against ./client/android in this image produces the
|
||||
# same netbird.aar that CI builds.
|
||||
#
|
||||
# Tooling versions (must stay in sync with the CI workflow):
|
||||
# - Ubuntu 22.04 (NOTE(review): `ubuntu-latest` may now resolve to a newer Ubuntu release; confirm against the runner image the CI job actually uses)
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - OpenJDK 11 from Ubuntu's openjdk-11-jdk-headless (CI: actions/setup-java@v3, java-version: 11, distribution: adopt)
|
||||
# - Android SDK cmdline-tools 8512546
|
||||
# - Android NDK 23.1.7779620
|
||||
# - gomobile commit v0.0.0-20251113184115-a159579294ab
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-android docker/build-env/android
|
||||
#
|
||||
# # bind the netbird checkout in and run the same gomobile command CI runs
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
# gomobile bind \
|
||||
# -o netbird.aar \
|
||||
# -javapkg=io.netbird.gomobile \
|
||||
# -ldflags="-checklinkname=0 \
|
||||
# -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
# -X github.com/netbirdio/netbird/version.version=local" \
|
||||
# ./client/android
|
||||
#
|
||||
# To build the full APK, mount the android-client repo too and run
# `./gradlew assembleDebug` from /android-client (this image carries
# gradle's prerequisites, the JDK and the Android SDK, but not the gradle
# wrapper, which ships with android-client).
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Versions — bump in lockstep with .github/workflows/mobile-build-validation.yml.
|
||||
ARG GO_VERSION=1.25.5
|
||||
ARG ANDROID_CMDLINE_TOOLS_VERSION=8512546
|
||||
ARG ANDROID_NDK_VERSION=23.1.7779620
|
||||
ARG GOMOBILE_VERSION=v0.0.0-20251113184115-a159579294ab
|
||||
|
||||
ENV ANDROID_HOME=/opt/android-sdk
|
||||
ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${ANDROID_NDK_VERSION}
|
||||
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV CGO_ENABLED=0
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/platform-tools:${JAVA_HOME}/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
unzip \
|
||||
git \
|
||||
openjdk-11-jdk-headless \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod). actions/setup-go fetches the same tarball.
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
# Install Android SDK command-line tools, accept licenses, install NDK.
|
||||
RUN mkdir -p "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& curl -fsSL -o /tmp/cmdline.zip \
|
||||
"https://dl.google.com/android/repository/commandlinetools-linux-${ANDROID_CMDLINE_TOOLS_VERSION}_latest.zip" \
|
||||
&& unzip -q /tmp/cmdline.zip -d "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" \
|
||||
&& rm /tmp/cmdline.zip \
|
||||
&& yes | sdkmanager --licenses > /dev/null \
|
||||
&& sdkmanager --install "ndk;${ANDROID_NDK_VERSION}" "platform-tools" > /dev/null
|
||||
|
||||
# Install gomobile at the same commit CI pins. Don't run `gomobile init` here:
|
||||
# `init` resolves the NDK at runtime, do it on the first bind in the mounted
|
||||
# workspace so the cache lands on the host volume.
|
||||
RUN GOBIN=/usr/local/bin go install "golang.org/x/mobile/cmd/gomobile@${GOMOBILE_VERSION}" \
|
||||
&& gomobile version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Default entrypoint is a plain shell so the image is composable: callers pass
|
||||
# the full gomobile / gradle command they want to run.
|
||||
CMD ["/bin/bash"]
|
||||
63
docker/build-env/windows-cross/Dockerfile
Normal file
63
docker/build-env/windows-cross/Dockerfile
Normal file
@@ -0,0 +1,63 @@
|
||||
# Windows-cross build environment.
|
||||
#
|
||||
# Cross-compiles Windows .exe targets from a Linux container using
|
||||
# mingw-w64. Mirrors the toolchain set used by
|
||||
# .github/workflows/golang-test-windows.yml insofar as that is possible
|
||||
# without a Windows kernel.
|
||||
#
|
||||
# IMPORTANT — what this image CAN do:
|
||||
# - `GOOS=windows go build ./...` to validate that Windows builds compile
|
||||
# - CGO Windows cross-compile via x86_64-w64-mingw32-gcc when CGO_ENABLED=1
|
||||
# (matches CI's choco-installed mingw-w64)
|
||||
#
|
||||
# IMPORTANT — what this image CANNOT do:
|
||||
# - Run Windows binaries (no Windows kernel under Docker on Linux).
|
||||
# - Replicate the CI's `go test` runs which execute on a real
|
||||
# windows-latest runner (wintun.dll, PsExec, registry, etc.).
|
||||
# Use the CI for that or a native Windows VM.
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
#
|
||||
# # Cross-compile a static client (.exe) from Linux:
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'CGO_ENABLED=1 GOOS=windows GOARCH=amd64 \
|
||||
# CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \
|
||||
# go build -o netbird.exe ./client'
|
||||
#
|
||||
# # Just validate that everything *compiles* on Windows (no CGO):
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
#
|
||||
# Tooling versions (keep in sync with go.mod and any future explicit pin
|
||||
# documented in golang-test-windows.yml):
|
||||
# - Ubuntu 22.04
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - mingw-w64 (Ubuntu package — pin further if drift becomes a problem)
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG GO_VERSION=1.25.5
|
||||
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
git \
|
||||
build-essential \
|
||||
mingw-w64 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod).
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
2
go.mod
2
go.mod
@@ -341,7 +341,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
||||
|
||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -510,8 +510,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a h1:3CWK+yTvRKOcC0Q8VCTGy4l60TEb27CQVS7LkMxwjmw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
|
||||
@@ -1,616 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird Enterprise — Getting Started
|
||||
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||
# embedded identity provider. Owner is created via first-login flow.
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
# check_docker_compose prints the Compose invocation to use: the standalone
# "docker-compose" binary when it is on PATH, otherwise the "docker compose"
# plugin. Exits with status 1 when neither is available.
check_docker_compose() {
  if command -v docker-compose &> /dev/null; then
    echo "docker-compose"
    return
  fi
  if docker compose --help &> /dev/null; then
    echo "docker compose"
    return
  fi
  # >&2 reuses the inherited stderr descriptor; re-opening /dev/stderr by path
  # can fail when stderr is a socket or belongs to another user.
  echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" >&2
  exit 1
}
|
||||
|
||||
# check_openssl exits with status 1 unless an "openssl" command is available.
check_openssl() {
  if ! command -v openssl &> /dev/null; then
    echo "openssl is not installed or not in PATH." >&2
    exit 1
  fi
}
|
||||
|
||||
# rand_secret prints 32 random bytes as base64 with the "=" padding removed
# (the sed expression comes from the script-level SED_STRIP_PADDING).
rand_secret() {
  openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
}
|
||||
|
||||
# rand_b64_key prints 32 random bytes as base64, padding included.
rand_b64_key() {
  openssl rand -base64 32
}
|
||||
|
||||
# check_nb_domain validates the domain passed as $1.
# Returns 0 when it looks like a multi-label FQDN; otherwise prints the reason
# to stderr and returns 1. Rejected: empty input, the documentation placeholder
# netbird.example.com, anything made only of digits and dots (IPv4-like), and
# anything that fails the FQDN pattern.
check_nb_domain() {
  local domain="$1"
  # Kept in a variable so the pattern is matched as a regex without any
  # quoting ambiguity inside [[ ... =~ ... ]].
  local fqdn_re='^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$'

  if [[ -z "$domain" ]]; then
    echo "The domain cannot be empty." >&2
    return 1
  fi
  if [[ "$domain" == "netbird.example.com" ]]; then
    echo "The domain cannot be netbird.example.com" >&2
    return 1
  fi
  if [[ "$domain" =~ ^[0-9.]+$ ]]; then
    echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." >&2
    return 1
  fi
  if [[ ! "$domain" =~ $fqdn_re ]]; then
    echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." >&2
    return 1
  fi
  return 0
}
|
||||
|
||||
# check_domain_resolves returns 0 when the domain in $1 resolves through the
# first available resolver tool that succeeds (getent, host, dig, nslookup, in
# that order) and 1 when none of them resolves it or none is installed.
check_domain_resolves() {
  local domain="$1"
  if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then
    return 0
  fi
  if command -v host &> /dev/null && host "$domain" &> /dev/null; then
    return 0
  fi
  # dig can exit 0 on an empty answer, so its output is checked instead.
  if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2> /dev/null)" ]]; then
    return 0
  fi
  if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then
    return 0
  fi
  return 1
}
|
||||
|
||||
# read_nb_domain prompts on the terminal for the NetBird FQDN until the input
# passes check_nb_domain. If the name does not resolve, the user is warned and
# may continue anyway or enter a different name. The accepted value is printed
# on stdout; all prompts and messages go to stderr.
# Returns 1 when the terminal cannot be read (no TTY or end of input) instead
# of re-prompting forever.
read_nb_domain() {
  local value=""
  local confirm=""

  while true; do
    echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " >&2
    if ! read -r value < /dev/tty && [[ -z "$value" ]]; then
      echo "" >&2
      echo "Unable to read the domain from the terminal." >&2
      return 1
    fi

    # check_nb_domain prints the rejection reason itself.
    if ! check_nb_domain "$value"; then
      continue
    fi

    if check_domain_resolves "$value"; then
      break
    fi

    echo "" >&2
    echo "Warning: '$value' does not resolve via DNS from this host." >&2
    echo "Caddy will not be able to issue TLS certificates until it does." >&2
    echo -n "Continue anyway? [y/N]: " >&2
    confirm=""
    if ! read -r confirm < /dev/tty && [[ -z "$confirm" ]]; then
      echo "" >&2
      echo "Unable to read the confirmation from the terminal." >&2
      return 1
    fi
    if [[ "$confirm" =~ ^[Yy]$ ]]; then
      break
    fi
  done

  echo "$value"
}
|
||||
|
||||
# read_required "<prompt>"
# Prompts on the terminal until a non-empty line is entered and prints it on
# stdout. Prompts and messages go to stderr. Returns 1 when the terminal
# cannot be read, so a missing TTY does not turn into an endless loop.
read_required() {
  local prompt="$1"
  local value=""

  while [[ -z "$value" ]]; do
    echo -n "$prompt: " >&2
    if ! read -r value < /dev/tty && [[ -z "$value" ]]; then
      echo "" >&2
      echo "Unable to read input from the terminal." >&2
      return 1
    fi
    if [[ -z "$value" ]]; then
      echo "Value cannot be empty." >&2
    fi
  done

  echo "$value"
}
|
||||
|
||||
# read_secret "<prompt>"
# Same as a required prompt, but the input is read with echo disabled
# (read -s). The entered value is printed on stdout for the caller to capture.
# Returns 1 when the terminal cannot be read, so a missing TTY does not turn
# into an endless loop.
read_secret() {
  local prompt="$1"
  local value=""

  while [[ -z "$value" ]]; do
    echo -n "$prompt: " >&2
    if ! read -rs value < /dev/tty && [[ -z "$value" ]]; then
      echo "" >&2
      echo "Unable to read input from the terminal." >&2
      return 1
    fi
    # read -s swallows the user's newline; emit one so the next prompt
    # starts on its own line.
    echo "" >&2
    if [[ -z "$value" ]]; then
      echo "Value cannot be empty." >&2
    fi
  done

  echo "$value"
}
|
||||
|
||||
# read_yes_no "<prompt>" [<default y|n>]
# Asks a yes/no question on the terminal and prints "yes" or "no" on stdout.
# An empty answer selects the default (n when omitted); only y / yes in any
# letter case counts as yes, everything else is "no".
read_yes_no() {
  local prompt="$1"
  local default="${2:-n}"
  local hint="[y/N]"
  local ans=""

  if [[ "$default" == "y" ]]; then
    hint="[Y/n]"
  fi

  echo -n "${prompt} ${hint}: " >&2
  read -r ans < /dev/tty
  if [[ -z "$ans" ]]; then
    ans="$default"
  fi

  case "$ans" in
    [Yy] | [Yy][Ee][Ss]) echo "yes" ;;
    *) echo "no" ;;
  esac
}
|
||||
|
||||
# wait_postgres polls pg_isready inside the postgres service every 2 seconds
# and returns once it succeeds. After the 60th failed probe it prints the last
# 20 log lines of the service and exits with status 1.
# Reads DOCKER_COMPOSE_COMMAND, POSTGRES_USER and POSTGRES_DB.
wait_postgres() {
  # errexit is switched off while polling and switched back on before
  # returning.
  set +e
  echo -n "Waiting for postgres to become ready"
  local counter=1

  # DOCKER_COMPOSE_COMMAND is left unquoted on purpose: it may hold two
  # words ("docker compose") that must be split into command and argument.
  until $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; do
    if [[ $counter -eq 60 ]]; then
      echo ""
      echo "Postgres is taking too long. Recent logs:"
      $DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
      exit 1
    fi
    echo -n " ."
    sleep 2
    counter=$((counter + 1))
  done

  echo " done"
  set -e
}
|
||||
|
||||
# init_environment is the interactive entry point of the bootstrap. It:
#   1. checks for openssl and a Compose command;
#   2. refuses to run when generated files already exist in the current
#      directory;
#   3. asks for the traffic-flow choice, the domain, the license key and the
#      GHCR token, and generates the database password, the encryption key and
#      the relay secret;
#   4. renders .env, docker-compose.yml, Caddyfile and config.yaml
#      (.env and config.yaml are created with mode 600 before being filled);
#   5. logs in to ghcr.io, pulls the images, starts postgres, waits for it and
#      then starts the remaining services.
init_environment() {
  check_openssl
  DOCKER_COMPOSE_COMMAND=$(check_docker_compose)

  if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
    echo "Generated files already exist in $(pwd)."
    echo "If you want to reinitialize the environment, please remove them first:"
    echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
    echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
    echo "Be aware this will remove all data from the database."
    exit 1
  fi

  echo "NetBird Enterprise bootstrap"
  echo ""
  echo "Traffic flow:"
  echo " Enables traffic events logging on the management server."
  echo " When enabled, the NetBird stack also runs NATS along with two"
  echo " additional containers: netbird-receiver (the traffic log receiver"
  echo " service) and netbird-enricher (the traffic log enricher service)."
  echo " It still has to be turned on from the dashboard settings afterwards."
  echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
  NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")

  echo ""
  NETBIRD_DOMAIN=$(read_nb_domain)

  echo ""

  NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")

  GHCR_USERNAME="netbirdExtAccess1"
  GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")

  POSTGRES_USER="netbird"
  POSTGRES_DB="netbird"
  POSTGRES_PASSWORD=$(rand_secret)
  NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
  NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)

  POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
  NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"

  echo ""
  echo "Selected:"
  echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
  echo " Domain: ${NETBIRD_DOMAIN}"
  echo ""
  echo "Rendering files into $(pwd) ..."
  # Create the secret-bearing files empty with mode 600 first, then append,
  # so the contents are never written under a wider default mode.
  install -m 600 /dev/null .env
  render_env >> .env
  render_docker_compose > docker-compose.yml

  # The compose template always references the optional license server URL;
  # drop those lines when no override was supplied.
  if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
    sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
  fi
  render_caddyfile > Caddyfile
  install -m 600 /dev/null config.yaml
  render_config_yaml >> config.yaml

  echo "Logging in to ghcr.io ..."
  # The token is piped on stdin so it never appears in the process list.
  printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
  unset GHCR_TOKEN

  echo ""
  echo "Pulling images ..."
  $DOCKER_COMPOSE_COMMAND pull

  echo ""
  echo "Starting postgres ..."
  $DOCKER_COMPOSE_COMMAND up -d postgres
  sleep 2
  wait_postgres

  echo ""
  echo "Starting remaining services ..."
  $DOCKER_COMPOSE_COMMAND up -d

  echo ""
  echo "Done."
  echo ""
  echo "Dashboard: https://${NETBIRD_DOMAIN}"
  echo ""
  echo "Open the dashboard in a browser to complete the first-login owner setup."
  echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
  echo ""
  echo "Tail logs:"
  echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# render_env prints the contents of the .env file on stdout, built from the
# variables set by init_environment. The enricher/receiver image tags are
# emitted only when traffic flow is enabled, and the license server URL only
# when NETBIRD_LICENSE_SERVER_BASE_URL is set. Image tags and OIDC scopes
# fall back to their defaults when the corresponding variable is unset.
render_env() {
  cat <<EOF
# Generated by getting-started-enterprise.sh
# Holds all configuration and secrets for the stack. Mode 600.

# Features (set by the script; don't edit without re-running)
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}

# Domain
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}

# Image tags. Default to "latest"
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
EOF

  if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
    cat <<EOF
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
EOF
  fi

  cat <<EOF

# License keys
EOF
  if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
    cat <<EOF
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
  fi

  cat <<EOF
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}

# Postgres
POSTGRES_USER=${POSTGRES_USER}
POSTGRES_DB=${POSTGRES_DB}
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}

# Relay
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}

# Datastore encryption
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}

# Dashboard OIDC scopes
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
EOF
}
|
||||
|
||||
# render_docker_compose prints the complete docker-compose.yml on stdout by
# concatenating the section renderers in file order. The traffic-flow
# services (NATS, enricher, receiver) are included only when
# NETBIRD_TRAFFIC_FLOW is "yes".
render_docker_compose() {
  render_compose_header
  render_compose_common
  render_compose_server
  if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
    render_compose_flow
  fi
  render_compose_postgres
  render_compose_footer
}
|
||||
|
||||
# render_compose_header prints the top of docker-compose.yml: the shared
# "x-default" anchor (restart policy and json-file log rotation) that every
# service merges in, followed by the opening "services:" key.
render_compose_header() {
  cat <<'EOF'
x-default: &default
  restart: unless-stopped
  logging:
    driver: json-file
    options:
      max-size: '500m'
      max-file: '2'

services:
EOF
}
|
||||
|
||||
render_compose_common() {
|
||||
cat <<'EOF'
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
container_name: netbird-caddy
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||
container_name: netbird-dashboard
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- AUTH_AUDIENCE=netbird-dashboard
|
||||
- AUTH_CLIENT_ID=netbird-dashboard
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||
- NGINX_SSL_PORT=443
|
||||
- LETSENCRYPT_DOMAIN=
|
||||
- LETSENCRYPT_EMAIL=
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# render_compose_server prints the netbird-server service definition. It
# starts after the dashboard and only once postgres reports healthy, publishes
# UDP 3478, and mounts the data volume and the generated config.yaml.
# The heredoc is quoted, so ${...} is left for Compose to interpolate.
render_compose_server() {
  cat <<'EOF'
  netbird-server:
    <<: *default
    image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
    container_name: netbird-server
    networks: [netbird]
    depends_on:
      dashboard:
        condition: service_started
      postgres:
        condition: service_healthy
    ports:
      - '3478:3478/udp'
    volumes:
      - netbird_data:/var/lib/netbird
      - ./config.yaml:/etc/netbird/config.yaml
    command: ["--config", "/etc/netbird/config.yaml"]
    environment:
      - NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
      - NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}

EOF
}
|
||||
|
||||
# render_compose_flow prints the optional traffic-flow services: NATS with
# JetStream, the flow enricher and the flow receiver.
# The heredoc is quoted, so ${...} is left for Compose to interpolate.
render_compose_flow() {
  cat <<'EOF'
  nats:
    <<: *default
    image: nats:2
    container_name: netbird-nats
    networks: [netbird]
    volumes:
      - netbird_nats_data:/data
    command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]

  enricher:
    <<: *default
    image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
    container_name: netbird-enricher
    networks: [netbird]
    depends_on:
      postgres:
        condition: service_healthy
      nats:
        condition: service_started
    volumes:
      - netbird_enricher:/var/lib/netbird
    environment:
      - NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
      - NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
      - NB_DATADIR=/var/lib/netbird
      - NB_MANAGEMENT_STORE_ENGINE=postgres
      - NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
      - NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
      - NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
      - NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
      - NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
      - NB_FLOW_ADAPTER_TYPE=nats
      - NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
      - NB_FLOW_NATS_STREAM=traffic-events
      - NB_METRICS_PORT=9091
      - NB_PERSISTENCE_RETENTION_PERIOD=168h

  receiver:
    <<: *default
    image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
    container_name: netbird-receiver
    networks: [netbird]
    depends_on:
      nats:
        condition: service_started
    environment:
      - NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
      - NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
      - NB_FLOW_LISTEN_PORT=80
      - NB_FLOW_ADAPTER_TYPE=nats
      - NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
      - NB_FLOW_NATS_STREAM=traffic-events
      - NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}

EOF
}
|
||||
|
||||
# render_compose_postgres prints the postgres service definition, including
# the pg_isready healthcheck that the other services' "service_healthy"
# conditions wait on.
# The heredoc is quoted, so ${...} is left for Compose to interpolate.
render_compose_postgres() {
  cat <<'EOF'
  postgres:
    <<: *default
    image: postgres:17
    container_name: netbird-postgres
    networks: [netbird]
    environment:
      - POSTGRES_USER=${POSTGRES_USER}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
      - POSTGRES_DB=${POSTGRES_DB}
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
      interval: 10s
      timeout: 5s
      retries: 10
    volumes:
      - netbird_postgres:/var/lib/postgresql/data

EOF
}
|
||||
|
||||
# render_compose_footer prints the closing "volumes:" and "networks:"
# sections of docker-compose.yml. The NATS and enricher volumes are declared
# only when NETBIRD_TRAFFIC_FLOW is "yes".
render_compose_footer() {
  cat <<'EOF'
volumes:
  netbird_data:
EOF
  if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
    cat <<'EOF'
  netbird_nats_data:
  netbird_enricher:
EOF
  fi
  cat <<'EOF'
  netbird_postgres:
  netbird_caddy_data:

networks:
  netbird:
EOF
}
|
||||
|
||||
# render_caddyfile prints the Caddyfile on stdout: global protocol options, a
# reusable security-headers snippet, an HTTP-to-HTTPS redirect and the TLS
# site block that routes each path prefix to netbird-server or the dashboard.
# The flow receiver route is added only when NETBIRD_TRAFFIC_FLOW is "yes".
# The heredocs are quoted, so {$CADDY_SECURE_DOMAIN} is written literally for
# Caddy to substitute from its environment.
render_caddyfile() {
  cat <<'EOF'
{
    servers :80,:443 {
        protocols h1 h2c h2 h3
    }
}

(security_headers) {
    header * {
        Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
        X-Content-Type-Options "nosniff"
        X-Frame-Options "SAMEORIGIN"
        X-XSS-Protection "1; mode=block"
        -Server
        Referrer-Policy strict-origin-when-cross-origin
    }
}

:80 {
    redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
}

{$CADDY_SECURE_DOMAIN}:443 {
    import security_headers
    # Signal (gRPC over h2c)
    reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
    # Management (gRPC over h2c + HTTP)
    reverse_proxy /management.ManagementService/* h2c://netbird-server:80
    reverse_proxy /api/* netbird-server:80
    reverse_proxy /ws-proxy/* netbird-server:80
    # Embedded IdP (OAuth2 endpoints served by netbird server)
    reverse_proxy /oauth2/* netbird-server:80
    # Relay (WebSocket multiplexed on the same port)
    reverse_proxy /relay* netbird-server:80
EOF

  if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
    cat <<'EOF'
    # Flow receiver (gRPC over h2c)
    reverse_proxy /flow.FlowService/* h2c://receiver:80
EOF
  fi

  # The catch-all dashboard route goes last and closes the site block.
  cat <<'EOF'
    # Dashboard
    reverse_proxy /* dashboard:80
}
EOF
}
|
||||
|
||||
# render_config_yaml prints the NetBird server config.yaml on stdout.
#
# The heredocs are unquoted, so the following variables are expanded while
# rendering and their values end up in the file in clear text:
#   NETBIRD_DOMAIN, NETBIRD_RELAY_AUTH_SECRET, POSTGRES_DSN,
#   NETBIRD_ENCRYPTION_KEY
# The `trafficFlow` section is appended only when NETBIRD_TRAFFIC_FLOW is "yes".
#
# Every setting is nested under the top-level `server:` key. A flat document
# would turn each setting into an unrelated top-level key and would repeat
# `engine`/`dsn`, which is invalid YAML.
render_config_yaml() {
  cat <<EOF
# NetBird Enterprise server configuration.
# Generated by getting-started-enterprise.sh. Mode 600.

server:
  listenAddress: ":80"
  exposedAddress: "https://${NETBIRD_DOMAIN}:443"

  metricsPort: 9090
  healthcheckAddress: ":9000"

  logLevel: "info"
  logFile: "console"

  # TLS is terminated by Caddy in front; leave this block empty.
  tls:
    certFile: ""
    keyFile: ""
    letsencrypt:
      enabled: false

  authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
  dataDir: "/var/lib/netbird/"

  disableAnonymousMetrics: false
  disableGeoliteUpdate: false

  auth:
    issuer: "https://${NETBIRD_DOMAIN}/oauth2"
    localAuthDisabled: false
    signKeyRefreshEnabled: false
    dashboardRedirectURIs:
      - "https://${NETBIRD_DOMAIN}/nb-auth"
      - "https://${NETBIRD_DOMAIN}/nb-silent-auth"
    cliRedirectURIs:
      - "http://localhost:53000/"

  store:
    engine: "postgres"
    dsn: "${POSTGRES_DSN}"
    encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"

  activityStore:
    engine: "postgres"
    dsn: "${POSTGRES_DSN}"
EOF

  if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
    cat <<EOF

  trafficFlow:
    enabled: true
    address: "https://${NETBIRD_DOMAIN}:443"
    interval: "60s"
EOF
  fi
}
|
||||
|
||||
init_environment
|
||||
@@ -351,11 +351,6 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
# Record whether the operator explicitly pinned the server/proxy images via
|
||||
# env vars, so the agent-network preset can pick its own defaults without
|
||||
# clobbering an explicit override.
|
||||
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
@@ -403,53 +398,7 @@ configure_domain() {
|
||||
return 0
|
||||
}
|
||||
|
||||
# apply_agent_network_preset locks the installer into the agent-network
# turnkey layout: built-in Traefik (REVERSE_PROXY_TYPE=0), NetBird Proxy
# enabled, CrowdSec disabled.
#
# Inputs (globals):
#   NETBIRD_SERVER_IMAGE_EXPLICIT / NETBIRD_PROXY_IMAGE_EXPLICIT - "true" when
#     the operator pinned the image through the environment; in that case the
#     corresponding image variable is left alone.
#   NETBIRD_LETSENCRYPT_EMAIL - used as the ACME e-mail when set; otherwise the
#     value is obtained from read_traefik_acme_email.
# Outputs (globals): REVERSE_PROXY_TYPE, ENABLE_PROXY, ENABLE_CROWDSEC,
#   NETBIRD_SERVER_IMAGE, NETBIRD_PROXY_IMAGE, TRAEFIK_ACME_EMAIL.
# A summary of the chosen settings is written to stderr.
apply_agent_network_preset() {
  REVERSE_PROXY_TYPE="0"
  ENABLE_PROXY="true"
  ENABLE_CROWDSEC="false"

  # Agent-network ships dedicated server/proxy images. Honor an explicit
  # env override; otherwise pin the agent-network builds.
  if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
    NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
  fi
  if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
    NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
  fi

  if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
    TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
  else
    TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
  fi

  # Write to the existing stderr descriptor. Redirecting to the /dev/stderr
  # path re-opens the target and truncates it when stderr is a regular file,
  # wiping whatever had been logged before.
  {
    echo ""
    echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):"
    echo " - reverse proxy: built-in Traefik"
    echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true"
    echo " - server image: ${NETBIRD_SERVER_IMAGE}"
    echo " - proxy image: ${NETBIRD_PROXY_IMAGE}"
    echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true"
    echo " - CrowdSec: disabled"
    echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}"
    echo ""
  } >&2
}
|
||||
|
||||
configure_reverse_proxy() {
|
||||
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
apply_agent_network_preset
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Prompt for reverse proxy type
|
||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||
|
||||
@@ -961,15 +910,6 @@ NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
LETSENCRYPT_DOMAIN=none
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||
# and exposes only the AI Observability + agent-network configuration
|
||||
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||
EOF
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -1006,17 +946,6 @@ NB_PROXY_PROXY_PROTOCOL=true
|
||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||
EOF
|
||||
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
cat <<EOF
|
||||
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||
# ingress for agent-network synth services. Disables the public-facing
|
||||
# surface so the proxy serves only synth-generated routes (the
|
||||
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||
# listeners on the embedded netstack.
|
||||
NB_PROXY_PRIVATE=true
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||
cat <<EOF
|
||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||
@@ -1397,20 +1326,12 @@ print_builtin_traefik_instructions() {
|
||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||
fi
|
||||
echo ""
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.ai/pricing"
|
||||
echo " Documentation: https://docs.netbird.io/agent-network"
|
||||
else
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
fi
|
||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||
echo ""
|
||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||
echo ""
|
||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||
echo "NetBird Proxy:"
|
||||
@@ -1433,11 +1354,6 @@ print_builtin_traefik_instructions() {
|
||||
echo ""
|
||||
fi
|
||||
fi
|
||||
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||
echo "Note: The public domain is only for setting up secure connections."
|
||||
echo "Your APIs and agent services remain private and are never exposed publicly."
|
||||
echo ""
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
@@ -1,638 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird — community combined → Enterprise combined migration
|
||||
#
|
||||
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||
# existing files. Original docker-compose.yml and config.yaml are never
|
||||
# modified.
|
||||
#
|
||||
# Steps (all optional, asked interactively):
|
||||
# 1. Image swap — replace community images with enterprise cloud images.
|
||||
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||
#
|
||||
# To revert:
|
||||
# docker compose down
|
||||
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||
# # at the end of this script's run.
|
||||
# docker compose up -d
|
||||
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
# check_docker_compose prints the Compose invocation available on this host:
# "docker-compose" when the standalone binary is on PATH, otherwise
# "docker compose" when the CLI plugin answers `--help`.
# If neither works it reports the problem on stderr and exits with status 1.
check_docker_compose() {
  if command -v docker-compose &> /dev/null; then
    echo "docker-compose"
    return
  fi
  if docker compose --help &> /dev/null; then
    echo "docker compose"
    return
  fi
  # Use the existing stderr descriptor: redirecting to the /dev/stderr path
  # re-opens (and truncates) the target when stderr is a regular file.
  echo "docker-compose is not installed or not in PATH." >&2
  exit 1
}
|
||||
|
||||
check_yq() {
|
||||
if ! command -v yq &> /dev/null; then
|
||||
cat > /dev/stderr <<'EOF'
|
||||
yq is required to parse and update YAML safely.
|
||||
|
||||
macOS: brew install yq
|
||||
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# check_openssl exits with status 1 and a message on stderr when the
# `openssl` binary is not on PATH; otherwise it does nothing.
check_openssl() {
  if ! command -v openssl &> /dev/null; then
    # >&2 rather than the /dev/stderr path, which truncates a redirected log.
    echo "openssl is not installed or not in PATH." >&2
    exit 1
  fi
}
|
||||
|
||||
# rand_password prints 32 random bytes from openssl as a hex string.
rand_password() {
  local byte_count=32
  openssl rand -hex "$byte_count"
}
|
||||
|
||||
# read_required prompts on stderr and reads one line from the controlling
# terminal, repeating until the answer is non-empty. The answer is printed on
# stdout so callers can capture it with $(...).
#
# $1 - prompt text (a ": " suffix is appended).
read_required() {
  local prompt="$1"
  local value=""
  while [[ -z "$value" ]]; do
    # Prompts go to the inherited stderr descriptor; the /dev/stderr path would
    # truncate a log file that stderr has been redirected to.
    echo -n "$prompt: " >&2
    read -r value < /dev/tty
    if [[ -z "$value" ]]; then
      echo "Value cannot be empty." >&2
    fi
  done
  echo "$value"
}
|
||||
|
||||
# read_secret behaves like a required prompt but reads with echo disabled
# (`read -s`), so the typed value is not shown on the terminal. The prompt and
# the newline that replaces the suppressed echo go to stderr; the value is
# printed on stdout for capture with $(...).
#
# $1 - prompt text (a ": " suffix is appended).
read_secret() {
  local prompt="$1"
  local value=""
  while [[ -z "$value" ]]; do
    # >&2 instead of the /dev/stderr path, which truncates a redirected log.
    echo -n "$prompt: " >&2
    read -rs value < /dev/tty
    echo "" >&2
    if [[ -z "$value" ]]; then
      echo "Value cannot be empty." >&2
    fi
  done
  echo "$value"
}
|
||||
|
||||
# read_yes_no asks a yes/no question on stderr, reads the answer from the
# controlling terminal and prints "yes" or "no" on stdout.
#
# $1 - question text.
# $2 - default used when the operator just presses Enter: "y" or "n"
#      (defaults to "n"). The hint shows the default in upper case.
# Only y / yes (any case) count as "yes"; every other answer is "no".
read_yes_no() {
  local prompt="$1"
  local default="${2:-n}"
  local hint
  if [[ "$default" == "y" ]]; then
    hint="[Y/n]"
  else
    hint="[y/N]"
  fi
  # >&2 instead of the /dev/stderr path, which truncates a redirected log.
  echo -n "${prompt} ${hint}: " >&2
  local ans=""
  read -r ans < /dev/tty
  if [[ -z "$ans" ]]; then
    ans="$default"
  fi
  case "$ans" in
    [Yy] | [Yy][Ee][Ss]) echo "yes" ;;
    *) echo "no" ;;
  esac
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# detect_combined_service prints the name of the first service in
# $COMPOSE_FILE whose image starts with "netbirdio/netbird-server", or an
# empty line when there is none.
#
# Services without an `image:` key (for example build-only services) are
# treated as having an empty image, because yq's `test` only accepts strings
# and would otherwise abort the whole detection.
detect_combined_service() {
  yq eval '.services | to_entries | map(select((.value.image // "") | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
}
|
||||
|
||||
# detect_dashboard_service prints the name of the first service in
# $COMPOSE_FILE whose image starts with "netbirdio/dashboard", or an empty
# line when there is none.
#
# Services without an `image:` key are treated as having an empty image,
# because yq's `test` only accepts strings and would otherwise abort.
detect_dashboard_service() {
  yq eval '.services | to_entries | map(select((.value.image // "") | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
}
|
||||
|
||||
# detect_config_yaml_host_path prints the host side of the volume entry on
# $COMBINED_SERVICE that is mounted at /etc/netbird/config.yaml (first match
# only), or nothing when no such short-syntax entry exists.
#
# Only string ("host:container[:mode]") entries are inspected. Long-syntax
# entries are maps, and yq's `test` aborts on anything that is not a string.
detect_config_yaml_host_path() {
  yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(tag == \"!!str\") | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
|
||||
|
||||
# detect_data_volume prints the source (named volume or host path) of the
# volume entry on $COMBINED_SERVICE that is mounted at /var/lib/netbird
# (first match only), or nothing when no such short-syntax entry exists.
#
# Only string ("source:target[:mode]") entries are inspected. Long-syntax
# entries are maps, and yq's `test` aborts on anything that is not a string.
detect_data_volume() {
  yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(tag == \"!!str\") | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
|
||||
|
||||
# detect_exposed_address prints server.exposedAddress from the file named by
# $CONFIG_YAML_HOST, or an empty line when the key is absent.
detect_exposed_address() {
  local query='.server.exposedAddress // ""'
  yq eval "$query" "$CONFIG_YAML_HOST"
}
|
||||
|
||||
# detect_compose_network prints the first network attached to
# $COMBINED_SERVICE in $COMPOSE_FILE.
#
# `networks` may be written as a list or as a map; the YAML tag of the node
# decides how the first entry is extracted. Any other shape (including a
# missing key or a failing yq call) yields the literal "default".
detect_compose_network() {
  local networks_path=".services[\"$COMBINED_SERVICE\"].networks"
  local node_kind
  node_kind=$(yq eval "${networks_path} | tag" "$COMPOSE_FILE" 2>/dev/null)

  if [[ "$node_kind" == "!!seq" ]]; then
    yq eval "${networks_path}[0]" "$COMPOSE_FILE"
  elif [[ "$node_kind" == "!!map" ]]; then
    yq eval "${networks_path} | keys | .[0]" "$COMPOSE_FILE"
  else
    echo "default"
  fi
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# render_override prints docker-compose.override.yml on stdout, built from the
# steps the operator selected. Service names match what was detected in the
# operator's existing compose file.
#
# Inputs (globals):
#   DASHBOARD_SERVICE, COMBINED_SERVICE - detected service names to override
#   MIGRATE_POSTGRES ("yes"/other)      - adds the postgres service, the
#                                         enterprise config mount and command
#   ENABLE_FLOW ("yes"/other)           - adds nats, flow-enricher, flow-receiver
#   COMPOSE_NETWORK, ENTERPRISE_CONFIG_FILE, NETBIRD_HOSTNAME
#
# The heredocs are unquoted: plain ${VAR} is expanded while rendering, whereas
# \${VAR} is written literally so Docker Compose interpolates it later (the
# caller stores those values in .env). Backticks in the Traefik rule are
# escaped for the same reason.
#
# Keys are nested per the Compose schema: services at two spaces, their
# settings at four. A flat document is not a valid override file.
render_override() {
  cat <<EOF
# Generated by migrate-to-enterprise.sh. Mode 644.
# Merged with docker-compose.yml automatically by Docker Compose.
# Remove this file (and config.yaml.enterprise if present) to revert.

services:
  ${DASHBOARD_SERVICE}:
    image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}

  ${COMBINED_SERVICE}:
    image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
    environment:
      NB_LICENSE_KEY: \${NB_LICENSE_KEY}
      NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF

  if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
    # The first keys continue the combined-server service opened above.
    cat <<EOF
    depends_on:
      postgres:
        condition: service_healthy
    volumes:
      - ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
    command: ["--config", "/etc/netbird/config.yaml.enterprise"]

  postgres:
    image: postgres:17
    container_name: netbird-postgres
    restart: unless-stopped
    networks: [${COMPOSE_NETWORK}]
    environment:
      POSTGRES_USER: netbird
      POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
      POSTGRES_DB: netbird
    volumes:
      - netbird_postgres:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
      interval: 5s
      timeout: 5s
      retries: 20
EOF
  fi

  if [[ "$ENABLE_FLOW" == "yes" ]]; then
    cat <<EOF

  nats:
    image: nats:2
    container_name: netbird-nats
    restart: unless-stopped
    networks: [${COMPOSE_NETWORK}]
    command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
    volumes:
      - netbird_nats_data:/data

  flow-enricher:
    image: ghcr.io/netbirdio/flow-enricher-cloud:latest
    container_name: netbird-flow-enricher
    restart: unless-stopped
    networks: [${COMPOSE_NETWORK}]
    depends_on:
      postgres:
        condition: service_healthy
      nats:
        condition: service_started
    environment:
      NB_LICENSE_KEY: \${NB_LICENSE_KEY}
      NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
      NB_DATADIR: /var/lib/netbird
      NB_MANAGEMENT_STORE_ENGINE: postgres
      NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
      NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
      NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
      NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
      NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
      NB_FLOW_ADAPTER_TYPE: nats
      NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
      NB_FLOW_NATS_STREAM: traffic-events
      NB_METRICS_PORT: 9091
      NB_PERSISTENCE_RETENTION_PERIOD: 168h

  flow-receiver:
    image: ghcr.io/netbirdio/flow-receiver-cloud:latest
    container_name: netbird-flow-receiver
    restart: unless-stopped
    networks: [${COMPOSE_NETWORK}]
    depends_on:
      nats:
        condition: service_started
    environment:
      NB_LICENSE_KEY: \${NB_LICENSE_KEY}
      NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
      NB_FLOW_LISTEN_PORT: 80
      NB_FLOW_ADAPTER_TYPE: nats
      NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
      NB_FLOW_NATS_STREAM: traffic-events
      NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
    labels:
      - traefik.enable=true
      - traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
      - traefik.http.routers.netbird-flow.entrypoints=websecure
      - traefik.http.routers.netbird-flow.tls=true
      - traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
      - traefik.http.routers.netbird-flow.service=netbird-flow-h2c
      - traefik.http.routers.netbird-flow.priority=100
      - traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
      - traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
EOF
  fi

  # Volume declarations for anything new the override introduced.
  if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
    cat <<EOF

volumes:
EOF
    if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
      echo "  netbird_postgres:"
    fi
    if [[ "$ENABLE_FLOW" == "yes" ]]; then
      echo "  netbird_nats_data:"
    fi
  fi
}
|
||||
|
||||
# render_enterprise_config writes $ENTERPRISE_CONFIG_FILE as a yq-edited copy
# of the operator's existing config ($CONFIG_YAML_HOST). The original file is
# only read, never modified.
#
# Changes applied to the copy:
#   - server.store, server.activityStore and server.authStore are switched to
#     the "postgres" engine with a DSN built from $POSTGRES_PASSWORD
#   - when ENABLE_FLOW is "yes", server.trafficFlow is enabled with
#     $NETBIRD_DOMAIN as its address and a 60s interval
#
# Values are handed to yq through the environment (strenv) instead of being
# spliced into the expression text, so a quote or backslash in a value cannot
# break or change the yq program.
render_enterprise_config() {
  local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"

  PG_DSN="$pg_dsn" yq eval '
    .server.store.engine = "postgres" |
    .server.store.dsn = strenv(PG_DSN) |
    .server.activityStore.engine = "postgres" |
    .server.activityStore.dsn = strenv(PG_DSN) |
    .server.authStore.engine = "postgres" |
    .server.authStore.dsn = strenv(PG_DSN)
  ' "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"

  if [[ "$ENABLE_FLOW" == "yes" ]]; then
    FLOW_ADDR="${NETBIRD_DOMAIN}" yq eval -i '
      .server.trafficFlow.enabled = true |
      .server.trafficFlow.address = strenv(FLOW_ADDR) |
      .server.trafficFlow.interval = "60s"
    ' "$ENTERPRISE_CONFIG_FILE"
  fi
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# resolve_data_volume turns the volume source found in the compose file into
# something `docker run -v` can use from any working directory, and prints it.
#
# $1 - volume source as written in the compose file.
#
# Resolution order:
#   1. A named volume: the project-prefixed name reported by
#      `$DOCKER_COMPOSE_COMMAND config` under .volumes.<name>.name.
#   2. A relative bind mount (./... or ../...): made absolute relative to the
#      directory of $COMPOSE_FILE, because Compose resolves it against that
#      directory while `docker run -v` resolves against the current directory.
#   3. Anything else (e.g. an absolute path): printed unchanged.
#
# Returns non-zero and prints nothing when the directory of a relative bind
# mount cannot be entered. Printing a path built from the wrong directory
# would make the caller mount (and Docker create) an unrelated location.
resolve_data_volume() {
  local short="$1"
  local actual
  # Resolve project-prefixed volume name from Docker Compose config first.
  actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
  if [[ -n "$actual" && "$actual" != "null" ]]; then
    echo "$actual"
    return
  fi

  if [[ "$short" == ./* || "$short" == ../* ]]; then
    local compose_dir
    compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)" || return 1
    # The subshell keeps the directory changes away from the caller.
    (
      cd "$compose_dir" &&
        cd "$(dirname "$short")" &&
        printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
    )
    return
  fi

  # Not a named volume (e.g. an absolute bind-mount path) - use it as-is.
  echo "$short"
}
|
||||
|
||||
# backup_sqlite copies the contents of the server's data volume into a fresh,
# timestamped directory under ./backups and aborts the script when nothing was
# copied.
#
# Inputs (globals): DATA_VOLUME - volume source detected in the compose file.
# Outputs (globals): BACKUP_DIR - absolute path of the backup directory; it is
#   left global because the revert instructions printed later refer to it.
#
# The copy runs in a throw-away busybox container with the volume mounted
# read-only. cp errors are deliberately suppressed there; the emptiness check
# below is the only validation of the result.
backup_sqlite() {
  BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
  mkdir -p "$BACKUP_DIR"

  local data_volume_actual
  data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
  echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
  docker run --rm \
    -v "${data_volume_actual}:/var/lib/netbird:ro" \
    -v "${BACKUP_DIR}:/backup" \
    busybox \
    sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'

  # Stop at the first entry. Piping find into `head -1` can kill find with
  # SIGPIPE on a large backup, which `set -o pipefail` plus `set -e` would turn
  # into a silent abort right after a successful copy.
  local copied
  copied=$(find "$BACKUP_DIR" -mindepth 1 -print -quit)
  if [[ -z "$copied" ]]; then
    # >&2 instead of the /dev/stderr path, which truncates a redirected log.
    echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." >&2
    exit 1
  fi
  echo " done"
}
|
||||
|
||||
# run_migrate_store runs the server image's `migrate-store` sub-command once
# (via `compose run --rm` on $COMBINED_SERVICE) against the enterprise config
# path inside the container, with --verify.
run_migrate_store() {
  local enterprise_cfg="/etc/netbird/config.yaml.enterprise"

  printf '%s\n' "Running migrate-store (SQLite → Postgres) ..."
  $DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" \
    migrate-store --config "$enterprise_cfg" --verify
  printf '%s\n' " done"
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# init_migration validates the environment, inspects the existing deployment
# and interactively collects every decision and secret the migration needs.
# Nothing is written to disk here.
#
# Outputs (globals):
#   DOCKER_COMPOSE_COMMAND, COMPOSE_FILE
#   COMBINED_SERVICE, DASHBOARD_SERVICE, CONFIG_YAML_HOST, DATA_VOLUME,
#   COMPOSE_NETWORK                       - detected from the compose file
#   MIGRATE_IMAGES (always "yes"), NB_LICENSE_KEY, GHCR_USERNAME, GHCR_TOKEN
#   MIGRATE_POSTGRES, POSTGRES_PASSWORD   - step 2
#   ENABLE_FLOW, NB_FLOW_AUTH_SECRET, NETBIRD_DOMAIN, NETBIRD_HOSTNAME,
#   NETBIRD_ENCRYPTION_KEY                - step 3 (only offered with Postgres)
#
# GHCR_USERNAME may be preset in the environment; it defaults to the account
# name this script has always used.
#
# Exits 1 with a message on stderr when a prerequisite or an expected part of
# the deployment is missing, and exits 0 when the operator declines to proceed.
init_migration() {
  DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
  check_yq
  check_openssl

  COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"

  if [[ ! -f "$COMPOSE_FILE" ]]; then
    echo "$COMPOSE_FILE not found in $(pwd)." >&2
    exit 1
  fi
  if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
    # This is a fatal condition, so it is reported on stderr like the others.
    {
      echo "Migration artifacts already exist in $(pwd):"
      if [[ -f "$OVERRIDE_FILE" ]]; then
        echo " $OVERRIDE_FILE"
      fi
      if [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
        echo " $ENTERPRISE_CONFIG_FILE"
      fi
      echo ""
      echo "Either you've already migrated, or a previous run was interrupted."
      echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
    } >&2
    exit 1
  fi

  # The mount and network lookups below are keyed by the combined service
  # name, so make sure it was found before running them.
  COMBINED_SERVICE=$(detect_combined_service)
  if [[ -z "$COMBINED_SERVICE" ]]; then
    echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." >&2
    echo "This script targets the community combined-server deployment." >&2
    exit 1
  fi

  DASHBOARD_SERVICE=$(detect_dashboard_service)
  if [[ -z "$DASHBOARD_SERVICE" ]]; then
    echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." >&2
    exit 1
  fi

  CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
  if [[ -z "$CONFIG_YAML_HOST" ]]; then
    echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." >&2
    exit 1
  fi
  if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
    echo "config.yaml host file not found at $CONFIG_YAML_HOST." >&2
    exit 1
  fi

  DATA_VOLUME=$(detect_data_volume)
  if [[ -z "$DATA_VOLUME" ]]; then
    echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." >&2
    exit 1
  fi

  COMPOSE_NETWORK=$(detect_compose_network)

  echo "Detected existing deployment:"
  echo " Combined service: $COMBINED_SERVICE"
  echo " Dashboard: $DASHBOARD_SERVICE"
  echo " config.yaml: $CONFIG_YAML_HOST"
  echo " Data volume: $DATA_VOLUME"
  echo " Network: $COMPOSE_NETWORK"
  echo ""

  local proceed
  proceed=$(read_yes_no "Proceed with migration?" "y")
  if [[ "$proceed" != "yes" ]]; then
    echo "Aborted."
    exit 0
  fi

  # Step 1 - always (this is the point of the script)
  MIGRATE_IMAGES="yes"
  echo ""
  echo "Step 1: Image swap (community → Enterprise). License key required."
  NB_LICENSE_KEY=$(read_secret " License key")
  GHCR_USERNAME="${GHCR_USERNAME:-netbirdExtAccess1}"
  GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")

  # Step 2 - optional
  echo ""
  MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
  if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
    echo ""
    echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
    echo " will be backed up automatically. To fully revert later, restore"
    echo " that backup and delete docker-compose.override.yml +"
    echo " config.yaml.enterprise."
    local confirm
    confirm=$(read_yes_no " Continue?" "y")
    if [[ "$confirm" != "yes" ]]; then
      MIGRATE_POSTGRES="no"
      echo " Skipping Postgres migration."
    else
      POSTGRES_PASSWORD=$(rand_password)
    fi
  fi

  # Step 3 - optional, only if Postgres is on (flow requires Postgres)
  echo ""
  if [[ "$MIGRATE_POSTGRES" != "yes" ]]; then
    ENABLE_FLOW="no"
    echo "Step 3 (traffic flow) skipped — requires Postgres."
    return
  fi

  ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
  if [[ "$ENABLE_FLOW" != "yes" ]]; then
    return
  fi

  # Auth secret MUST match server.authSecret from config.yaml
  NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
  if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
    echo "Could not read server.authSecret from $CONFIG_YAML_HOST." >&2
    echo "Flow receiver auth must match the combined server's authSecret." >&2
    exit 1
  fi

  NETBIRD_DOMAIN=$(detect_exposed_address)
  if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
    NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
  fi
  # Strip protocol + port to leave just the hostname for the Traefik Host() rule.
  NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')

  # We need the encryption key from the existing config.yaml for the enricher
  NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
  if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
    echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." >&2
    exit 1
  fi
}
|
||||
|
||||
# apply_changes performs the migration that init_migration planned:
#   1. writes the compose override (mode 644); the license-server URL lines are
#      dropped from it when NETBIRD_LICENSE_SERVER_BASE_URL is unset or empty
#   2. writes the enterprise config (mode 600) when Postgres was selected
#   3. appends the collected secrets to ./.env (mode 600)
#   4. logs in to ghcr.io and pulls the images
#   5. with Postgres selected: stops the stack, backs up SQLite, starts
#      Postgres, waits for it and runs migrate-store
#   6. brings every service up
# Any failing command aborts the script (the script runs under `set -e`).
apply_changes() {
  printf '\n%s\n' "Writing $OVERRIDE_FILE ..."
  # Create the file with its final mode before any content is written to it.
  install -m 644 /dev/null "$OVERRIDE_FILE"
  render_override > "$OVERRIDE_FILE"

  if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
    sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
  fi

  if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
    printf '%s\n' "Writing $ENTERPRISE_CONFIG_FILE ..."
    install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
    render_enterprise_config
  fi

  # The override file references these values through env interpolation;
  # Docker Compose reads .env from the current directory automatically.
  printf '%s\n' "Writing .env additions (mode 600) ..."
  local env_path=".env"
  touch "$env_path"
  chmod 600 "$env_path"
  {
    printf '\n'
    printf '%s\n' "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
    printf '%s\n' "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
    if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
      printf '%s\n' "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
    fi
    if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
      printf '%s\n' "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
    fi
    if [[ "$ENABLE_FLOW" == "yes" ]]; then
      printf '%s\n' "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
      printf '%s\n' "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
    fi
  } >> "$env_path"

  printf '\n%s\n' "Logging in to ghcr.io ..."
  # The token travels over stdin so it never appears in the process list.
  printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
  unset GHCR_TOKEN

  printf '\n%s\n' "Pulling enterprise images ..."
  $DOCKER_COMPOSE_COMMAND pull

  if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
    printf '\n%s\n' "Stopping existing services (volumes preserved) ..."
    $DOCKER_COMPOSE_COMMAND down

    backup_sqlite

    printf '\n%s\n' "Starting Postgres ..."
    $DOCKER_COMPOSE_COMMAND up -d postgres

    # Poll pg_isready every 2 seconds; give up after 60 failed probes.
    local tries=0
    echo -n "Waiting for Postgres to become ready"
    until $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
      echo -n " ."
      sleep 2
      tries=$((tries + 1))
      if [[ $tries -lt 60 ]]; then
        continue
      fi
      echo ""
      echo "Postgres did not become ready in 120s. Recent logs:"
      $DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
      exit 1
    done
    echo " done"

    run_migrate_store
  fi

  printf '\n%s\n' "Bringing up all services ..."
  $DOCKER_COMPOSE_COMMAND up -d

  printf '\n%s\n' "Migration complete."
}
|
||||
|
||||
# print_summary prints what the migration did, which files it generated, how
# to tail the server logs and the exact commands needed to revert.
#
# Reads (globals): MIGRATE_POSTGRES, ENABLE_FLOW, OVERRIDE_FILE,
#   ENTERPRISE_CONFIG_FILE, DOCKER_COMPOSE_COMMAND, COMBINED_SERVICE and, when
#   Postgres was migrated, DATA_VOLUME and BACKUP_DIR.
# The volume names in the revert commands are resolved while the override file
# still exists.
print_summary() {
  local rule="──────────────────────────────────────────────────────────────────────"
  local with_postgres="no"
  if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
    with_postgres="yes"
  fi

  echo ""
  echo "$rule"
  echo " Summary"
  echo "$rule"
  echo " Images: swapped to enterprise"
  if [[ "$with_postgres" == "yes" ]]; then
    echo " Storage: Postgres (data migrated from SQLite)"
  else
    echo " Storage: SQLite (unchanged)"
  fi
  if [[ "$ENABLE_FLOW" == "yes" ]]; then
    echo " Traffic flow: enabled"
  else
    echo " Traffic flow: disabled"
  fi
  echo ""
  echo " Generated files (next to your docker-compose.yml):"
  echo " $OVERRIDE_FILE"
  if [[ "$with_postgres" == "yes" ]]; then
    echo " $ENTERPRISE_CONFIG_FILE"
  fi
  echo " .env (license key + secrets, mode 600)"
  if [[ "$with_postgres" == "yes" ]]; then
    echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
  fi
  echo ""
  echo " Tail logs:"
  echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
  echo ""
  echo "$rule"
  echo " To revert"
  echo "$rule"
  echo " $DOCKER_COMPOSE_COMMAND down"
  if [[ "$with_postgres" == "yes" ]]; then
    # Resolve project-prefixed volume names now (before override is removed).
    local pg_volume data_volume_actual
    pg_volume=$(resolve_data_volume "netbird_postgres")
    data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
    echo " # Remove the Postgres volume FIRST, before deleting the override file:"
    echo " docker volume rm $pg_volume"
    echo " # Restore SQLite from the backup created during this run:"
    echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
  fi
  echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
  echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
  echo " $DOCKER_COMPOSE_COMMAND up -d"
  echo "$rule"
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration
|
||||
apply_changes
|
||||
print_summary
|
||||
@@ -497,7 +497,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
@@ -610,10 +610,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -13,7 +13,7 @@ const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnThreshold that triggers a ban
|
||||
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
@@ -139,7 +139,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
@@ -147,6 +147,14 @@ func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return h.Sum64()
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
}
|
||||
|
||||
@@ -164,7 +164,9 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
@@ -173,7 +175,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta)
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -181,7 +183,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta)
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -189,7 +191,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta)
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -197,20 +199,29 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
@@ -224,10 +235,23 @@ func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
|
||||
@@ -33,8 +33,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -84,9 +82,6 @@ type ProxyServiceServer struct {
|
||||
// Manager for users
|
||||
usersManager users.Manager
|
||||
|
||||
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||
idpManager idp.Manager
|
||||
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
@@ -162,7 +157,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -171,7 +166,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
idpManager: idpManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
@@ -1708,7 +1702,22 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
@@ -1745,45 +1754,6 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||
// user or peer ID, and peer name or user email.
|
||||
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||
// principal so multiple peers owned by the same user share a single
|
||||
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||
// user.Email when linked, peer.Name when not.
|
||||
|
||||
// If the peer isn't associated with a user, return the peer info directly.
|
||||
if peer.UserID == "" {
|
||||
return peer.ID, peer.Name
|
||||
}
|
||||
|
||||
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||
// if an IdP is available, we gather details on the user from it.
|
||||
principalID := peer.UserID
|
||||
displayIdentity := peer.Name
|
||||
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
// IdP enrichment wins when available — the stored email column is a
|
||||
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||
if s.idpManager != nil {
|
||||
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||
displayIdentity = ud.Email
|
||||
} else if uerr != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||
}
|
||||
}
|
||||
|
||||
return principalID, displayIdentity
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
|
||||
@@ -3,19 +3,14 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
@@ -142,52 +137,6 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||
// panics if any unexpected method is invoked).
|
||||
type mockTunnelPeersManager struct {
|
||||
peers.Manager
|
||||
peer *peer.Peer
|
||||
peerErr error
|
||||
groups []*types.Group
|
||||
groupsErr error
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||
return m.peer, m.peerErr
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||
return m.peer, m.groups, m.groupsErr
|
||||
}
|
||||
|
||||
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||
// an IdP that knows nothing about the user.
|
||||
type mockTunnelIdpManager struct {
|
||||
idp.Manager
|
||||
email string
|
||||
hasData bool
|
||||
err error
|
||||
gotCalls int
|
||||
gotMeta []idp.AppMetadata
|
||||
}
|
||||
|
||||
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||
m.gotCalls++
|
||||
m.gotMeta = append(m.gotMeta, meta)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
if !m.hasData {
|
||||
// This might not be a thing any of the actual IDP implementations do,
|
||||
// i.e. return a nil value with no error, but it seems valuable to test
|
||||
// that behavior here.
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -405,163 +354,6 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||
// (IdP email -> stored User.Email -> peer.Name).
|
||||
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||
const (
|
||||
domain = "app.example.com"
|
||||
accountID = "account1"
|
||||
peerID = "peer1"
|
||||
peerName = "peer-display-name"
|
||||
userID = "user1"
|
||||
)
|
||||
|
||||
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
peerUserID string
|
||||
storedUsers map[string]*types.User
|
||||
storedErr error
|
||||
noIdP bool
|
||||
idpEmail string
|
||||
idpHasData bool
|
||||
idpErr error
|
||||
expectEmail string
|
||||
expectUserID string
|
||||
expectIdPHit bool
|
||||
}{
|
||||
{
|
||||
name: "idp email wins over stored email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp returns empty email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "",
|
||||
idpHasData: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp has no data",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpHasData: false,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp errors",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpErr: errors.New("idp unreachable"),
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when no idp manager",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
noIdP: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored email is empty",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||
peerUserID: userID,
|
||||
storedUsers: map[string]*types.User{},
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "unlinked peer uses peer name and never consults idp",
|
||||
peerUserID: "",
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: peerID,
|
||||
expectIdPHit: false,
|
||||
},
|
||||
{
|
||||
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
noIdP: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: userID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||
},
|
||||
peersManager: &mockTunnelPeersManager{
|
||||
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||
},
|
||||
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||
}
|
||||
|
||||
var idpMock *mockTunnelIdpManager
|
||||
if !tt.noIdP {
|
||||
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||
server.idpManager = idpMock
|
||||
}
|
||||
|
||||
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||
Domain: domain,
|
||||
TunnelIp: "100.64.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.GetValid(), "expected access granted")
|
||||
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||
|
||||
if idpMock != nil {
|
||||
if tt.expectIdPHit {
|
||||
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||
require.Len(t, idpMock.gotMeta, 1)
|
||||
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||
} else {
|
||||
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta)
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
@@ -788,11 +788,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
ExtraDNSLabels: loginReq.GetDnsLabels(),
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
|
||||
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -1209,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
@@ -1258,10 +1254,7 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
check := toProtocolCheck(postureCheck)
|
||||
if check != nil {
|
||||
protoChecks = append(protoChecks, check)
|
||||
}
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
@@ -1285,9 +1278,5 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
}
|
||||
}
|
||||
|
||||
if len(protoCheck.Files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2045,7 +2045,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam
|
||||
Extra: &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
},
|
||||
LazyConnectionEnabled: true,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
|
||||
@@ -62,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -123,7 +123,7 @@ type Manager interface {
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
|
||||
@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks base method.
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
|
||||
}
|
||||
|
||||
// SyncUserJWTGroups mocks base method.
|
||||
|
||||
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -2046,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2067,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2091,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2101,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2163,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2204,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3326,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -106,9 +106,11 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
|
||||
for _, peerID := range mustExclude {
|
||||
assert.NotContains(t, affected, peerID, "peer must not be affected")
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,9 +251,7 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
|
||||
}
|
||||
}
|
||||
|
||||
// A disabled sibling router routes to nobody, so updating a resource on its network
|
||||
// must NOT refresh its peer (the enabled router carries the bridge instead).
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -276,18 +274,13 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh, enabledCh)
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, enabledCh)
|
||||
peerShouldNotReceiveUpdate(t, disabledCh)
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -305,7 +298,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *tes
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout")
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -682,9 +682,6 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
// A disabled router in the snapshot routes to nobody, so it is skipped when the
|
||||
// walk scans existing account data: a policy edit still folds the literal source
|
||||
// group, but not the disabled router's peer.
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -697,13 +694,11 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
// A disabled resource in the snapshot is skipped: the policy edit still folds the
|
||||
// literal source group, but the resource no longer bridges to its network's router.
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -715,9 +710,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
|
||||
@@ -96,54 +96,33 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
|
||||
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
|
||||
|
||||
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
|
||||
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
|
||||
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
@@ -154,44 +133,20 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Contains(t, directPeers, peerIDs[3])
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
|
||||
@@ -213,7 +168,8 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
@@ -338,7 +294,6 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -369,7 +324,6 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
Peer: peerIDs[4],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -419,11 +373,17 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.NotContains(t, groups, groupIDs[2])
|
||||
assert.NotContains(t, groups, groupIDs[3])
|
||||
assert.Empty(t, directPeers)
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
|
||||
assert.Contains(t, groups, groupIDs[2])
|
||||
assert.Contains(t, groups, groupIDs[3])
|
||||
assert.NotContains(t, groups, groupIDs[0])
|
||||
assert.NotContains(t, groups, groupIDs[1])
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
@@ -492,9 +452,8 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
@@ -515,7 +474,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
@@ -547,9 +506,8 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
|
||||
@@ -606,9 +564,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
|
||||
// peer3 is unrelated to the route; only its own map can change.
|
||||
// peer3 is unrelated
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
@@ -629,7 +587,6 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -702,13 +659,9 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
|
||||
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
|
||||
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
|
||||
// of group1, is a sibling of the changed peer and must NOT refresh.
|
||||
// peer0 is in group0 AND group1, so both policies apply
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
@@ -744,7 +697,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
|
||||
@@ -901,9 +854,8 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
|
||||
assert.NotContains(t, result, peerIDs[0])
|
||||
assert.NotContains(t, result, peerIDs[1])
|
||||
|
||||
// peerIDs[4] is in neither isolated policy; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
|
||||
@@ -1025,13 +977,12 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// A peer in no policy/route refreshes only itself — no other peer is affected.
|
||||
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
|
||||
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
|
||||
@@ -1381,7 +1332,6 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{"nr-grpA"},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1805,9 +1755,7 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
|
||||
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
|
||||
}
|
||||
|
||||
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
|
||||
// target peer does not pull in the service's proxy peer.
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1833,7 +1781,8 @@ func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
require.NoError(t, s.CreateService(ctx, svc))
|
||||
|
||||
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
|
||||
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
|
||||
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
|
||||
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {
|
||||
|
||||
@@ -6,12 +6,7 @@
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
|
||||
// account policies/resources/routers/routes/proxy services and their rules/targets)
|
||||
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
|
||||
// the CHANGE itself are processed regardless of Enabled, so disabling one still
|
||||
// refreshes the peers that lose access (the toggle is the observable change, and the
|
||||
// update carries the old∪new state).
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
@@ -66,8 +61,7 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
@@ -82,7 +76,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -180,24 +174,6 @@ type Change struct {
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
|
||||
// OutputPeerIDs are peers folded straight into the result without seeding their
|
||||
// group memberships into the walk. Use for the peer whose group membership changed:
|
||||
// the peer itself must refresh, but its OTHER groups did not change, so they must
|
||||
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
|
||||
// (correct when the peer's own attributes changed, e.g. IP/status).
|
||||
OutputPeerIDs []string
|
||||
|
||||
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
|
||||
// OPPOSITE side — they are never expanded to their own members. Use this when a
|
||||
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
|
||||
// group(s) here. The opposite side of the policies the group participates in
|
||||
// refreshes, but the group's other members (siblings) do not — nothing changed for
|
||||
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
|
||||
// members still refresh via the opposite-side fold, exactly when they genuinely
|
||||
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
|
||||
// the output, so a one-sided membership change never wakes the whole group.
|
||||
LinkGroups []string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
@@ -210,9 +186,7 @@ func (c Change) isEmpty() bool {
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0 &&
|
||||
len(c.LinkGroups) == 0 &&
|
||||
len(c.OutputPeerIDs) == 0
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
@@ -223,8 +197,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
@@ -242,84 +216,57 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
linkGroups: toSet(c.ChangedGroupIDs),
|
||||
outputGroups: toSet(c.ChangedGroupIDs),
|
||||
changedPeers: toSet(c.ChangedPeerIDs),
|
||||
affectedGroups: make(map[string]struct{}),
|
||||
affectedPeers: make(map[string]struct{}),
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
}
|
||||
// LinkGroups match policies/routes to find the opposite side but are NOT output:
|
||||
// they go into linkGroups only, never outputGroups, so their members never fold in.
|
||||
addAll(r.linkGroups, c.LinkGroups)
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
// These seeded groups are for MATCHING only — folding the changed entity's own
|
||||
// side is gated on outputGroups (the caller-reported groups), so a seeded group
|
||||
// never folds its whole membership; only the changed peer itself folds in.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeers) == 0 {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.linkGroups[groupID] = struct{}{}
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policySide selects which side of a policy rule to walk.
|
||||
type policySide int
|
||||
|
||||
const (
|
||||
sideSource policySide = iota
|
||||
sideDestination
|
||||
)
|
||||
|
||||
func (s policySide) opposite() policySide {
|
||||
if s == sideSource {
|
||||
return sideDestination
|
||||
}
|
||||
return sideSource
|
||||
}
|
||||
|
||||
// walk resolves affected peers in two buckets, by how far each change propagates.
|
||||
//
|
||||
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
|
||||
// posture check changed). Source AND destination refresh, so each such policy is
|
||||
// walked on both sides.
|
||||
//
|
||||
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
|
||||
// touches we fold only the side AWAY from the change:
|
||||
// - a changed peer/group sits ON a policy side -> fold the opposite side;
|
||||
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
|
||||
// the policies whose destination reaches it (and the routers it implies).
|
||||
//
|
||||
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
|
||||
// member peers, outside the policy graph, and are folded here too.
|
||||
func (r *resolver) walk() {
|
||||
for _, policy := range r.bothSidesPolicies() {
|
||||
r.foldPolicySide(policy, sideSource)
|
||||
r.foldPolicySide(policy, sideDestination)
|
||||
}
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
@@ -328,31 +275,7 @@ func (r *resolver) walk() {
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectFromChangedRoutes(r.change.Routes)
|
||||
r.collectFromChangedRouters(r.change.Routers)
|
||||
r.collectFromChangedResources(r.change.Resources)
|
||||
r.collectFromChangedNetworks(r.change.Networks)
|
||||
|
||||
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
|
||||
// refreshes the resolver's output (it ignores the separately-passed changed peers),
|
||||
// so the changed peer reaches its own new map only via here. An offline/deleted
|
||||
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
|
||||
addAll(r.affectedPeers, setToSlice(r.changedPeers))
|
||||
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
|
||||
// memberships were not seeded into the walk (only the changed group was).
|
||||
addAll(r.affectedPeers, r.change.OutputPeerIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into affectedGroups so expand() maps them to members, without the
|
||||
// policy/route walk that linkGroups would trigger.
|
||||
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
|
||||
}
|
||||
|
||||
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
|
||||
// plus those gated by a changed posture check. walk folds both their sides.
|
||||
func (r *resolver) bothSidesPolicies() []*types.Policy {
|
||||
policies := append([]*types.Policy(nil), r.change.Policies...)
|
||||
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
@@ -361,71 +284,27 @@ type resolver struct {
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
// Inputs — what changed. Set once at construction, read-only during the walk
|
||||
// (except linkGroups, which collectFromExplicitResources also seeds).
|
||||
//
|
||||
// linkGroups is the MATCH set: caller-changed groups ∪ the groups of changed
|
||||
// peers ∪ changed-resource groups. A rule/route/router matches the change when
|
||||
// one of its groups is here — used only to find the opposite side to fold.
|
||||
//
|
||||
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
|
||||
// matched group is here, its whole membership is affected. A peer-seeded group
|
||||
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
|
||||
// (changedPeers), never its siblings.
|
||||
linkGroups map[string]struct{}
|
||||
outputGroups map[string]struct{}
|
||||
changedPeers map[string]struct{}
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
|
||||
// is expanded to its member peers in expand().
|
||||
affectedGroups map[string]struct{}
|
||||
affectedPeers map[string]struct{}
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
}
|
||||
|
||||
// policies returns the account's ENABLED policies from the snapshot. Disabled
|
||||
// policies grant no access, so the walk skips them when scanning existing account
|
||||
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
|
||||
// processed regardless of Enabled, so disabling one still refreshes its peers.
|
||||
func (r *resolver) policies() []*types.Policy {
|
||||
enabled := make([]*types.Policy, 0, len(r.snap.policies))
|
||||
for _, policy := range r.snap.policies {
|
||||
if policy != nil && policy.Enabled {
|
||||
enabled = append(enabled, policy)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
|
||||
// networkResources / networkRouters return the account's ENABLED resources/routers
|
||||
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
|
||||
// it scans existing account data. The explicitly changed objects in the Change are
|
||||
// processed regardless of Enabled (collectFromChanged*), so disabling one still
|
||||
// refreshes the peers that lose access.
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
|
||||
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
|
||||
for _, resource := range r.snap.resources {
|
||||
if resource.Enabled {
|
||||
enabled = append(enabled, resource)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
|
||||
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
|
||||
for _, router := range r.snap.routers {
|
||||
if router.Enabled {
|
||||
enabled = append(enabled, router)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groups {
|
||||
for gID := range groupSet {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
@@ -438,25 +317,25 @@ func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.affectedGroups)
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.affectedPeers {
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in affectedGroups).
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.affectedGroups[groupID]; !linked {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
@@ -472,349 +351,169 @@ func (r *resolver) expand() []string {
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
|
||||
// side of a rule.
|
||||
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
|
||||
if side == sideDestination {
|
||||
return rule.Destinations
|
||||
}
|
||||
return rule.Sources
|
||||
}
|
||||
|
||||
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
|
||||
if side == sideDestination {
|
||||
return rule.DestinationResource
|
||||
}
|
||||
return rule.SourceResource
|
||||
}
|
||||
|
||||
// foldPolicySide folds one side of a policy down to affected peers: its groups
|
||||
// (resolved to members in expand) and its direct peer. When the side is the
|
||||
// DESTINATION and references a network resource (directly or via a destination
|
||||
// group's resources), it also folds the routers that serve that resource's network
|
||||
// — a destination resource is reached through its routers. A resource on the SOURCE
|
||||
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
|
||||
// so the router hop is destination-only.
|
||||
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
|
||||
if policy == nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
|
||||
}
|
||||
}
|
||||
|
||||
// appendPoliciesForPostureChecks appends every policy that references a changed
|
||||
// posture check (a rule change, so walk both sides).
|
||||
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return policies
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) || !policy.Enabled {
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
|
||||
policy.ID, policy.Name, postureCheckIDs)
|
||||
policies = append(policies, policy)
|
||||
}
|
||||
return policies
|
||||
}
|
||||
|
||||
// collectFromPolicies folds, for every policy whose rule a changed group or peer
|
||||
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
|
||||
// the changed entity's own side: the changed group's whole membership when the
|
||||
// group itself changed (outputGroups), or the changed peer alone when matched via a
|
||||
// peer-seeded group (never its co-members).
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue // a disabled rule grants no access
|
||||
}
|
||||
r.foldRuleSideIfChanged(policy, rule, sideSource)
|
||||
r.foldRuleSideIfChanged(policy, rule, sideDestination)
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
|
||||
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
|
||||
// the changed entity's own side (the whole changed group, or the changed peer alone).
|
||||
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
nearGroups := ruleSideGroups(rule, side)
|
||||
nearResource := ruleSideResource(rule, side)
|
||||
|
||||
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
|
||||
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
return
|
||||
}
|
||||
|
||||
// Opposite side, fully down to peers (a destination opposite also folds routers).
|
||||
r.foldPolicySideForRule(policy, rule, side.opposite())
|
||||
|
||||
// Own side: fold the whole changed group's members only when the group itself
|
||||
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
|
||||
// its siblings never refresh. The changed peers themselves are folded once, after
|
||||
// the walk (see walk()).
|
||||
for _, gID := range nearGroups {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// When the changed side IS a destination, the resources it targets are reached
|
||||
// through their network's routers, so those routers refresh too (e.g. attaching a
|
||||
// resource to a destination group, or a changed destination group/resource).
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
|
||||
// for a destination side the routers of that rule's destination resources.
|
||||
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
|
||||
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
|
||||
// (the changed entity) and the SOURCE side of every policy reaching a resource on
|
||||
// its network (the router serves the whole network). Sibling routers on the network
|
||||
// are independent and are NOT folded. Passing the old router state keeps a repointed
|
||||
// router's previous backing affected without a post-commit read.
|
||||
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
|
||||
// policies targeting EXACTLY that resource — directly, or via one of the resource's
|
||||
// own groups (old∪new across the change, so a now-detached group's sources still
|
||||
// refresh) — plus the routers serving its network (the resource is reached through
|
||||
// them). It does not touch sibling resources on the same network.
|
||||
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResource folds the source side of every policy whose
|
||||
// destination is the given resource — referenced directly, or via any of the given
|
||||
// groups (the resource's own old∪new groups, which captures a detached group).
|
||||
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
|
||||
groups := toSet(groupIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
}
|
||||
|
||||
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
|
||||
// resource directly, or one of the given destination groups.
|
||||
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
|
||||
return true
|
||||
}
|
||||
if anyInSet(rule.Destinations, groups) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
|
||||
// policies reaching any of its resources, plus its routers. A network has no
|
||||
// groups/peers of its own.
|
||||
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil || network.ID == "" {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
|
||||
resourceIDs := r.networkResourceIDs(network.ID)
|
||||
r.foldPolicySourcesForResources(resourceIDs)
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
|
||||
// destination targets one of resourceIDs (directly or via a destination group).
|
||||
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromRoutes folds, per matched route, the OPPOSITE side(s) fully and the
|
||||
// matched side's own groups only on a whole-group change (outputGroups). A route has
|
||||
// three peer sides — routing (Peer/PeerGroups), consumer (Groups) and ACL
|
||||
// (AccessControlGroups) — that each refresh the others; the changed side's own group
|
||||
// folds its siblings only when the group itself changed, never on a one-peer move.
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
if !rt.Enabled {
|
||||
continue // disabled routes route to nobody; skip existing account data
|
||||
}
|
||||
routing := anyInSet(rt.PeerGroups, r.linkGroups) || (rt.Peer != "" && isInSet(rt.Peer, r.changedPeers))
|
||||
consumer := anyInSet(rt.Groups, r.linkGroups)
|
||||
acl := anyInSet(rt.AccessControlGroups, r.linkGroups)
|
||||
if !routing && !consumer && !acl {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (routing=%t consumer=%t acl=%t) -> folding opposite sides; own side gated on outputGroups",
|
||||
rt.ID, routing, consumer, acl)
|
||||
r.foldRouteSide(rt.PeerGroups, routing)
|
||||
r.foldRouteSide(rt.Groups, consumer)
|
||||
r.foldRouteSide(rt.AccessControlGroups, acl)
|
||||
// The single routing Peer folds when the routing side is the OPPOSITE of the
|
||||
// match (consumer/acl need it), or when that very peer is the change.
|
||||
if rt.Peer != "" && (consumer || acl || isInSet(rt.Peer, r.changedPeers)) {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldRouteSide folds a route side: when this side is the one that matched, fold its
|
||||
// groups only on a whole-group change (outputGroups) so siblings of a single moved
|
||||
// peer stay put; otherwise it is an opposite side and folds fully.
|
||||
func (r *resolver) foldRouteSide(groups []string, matchedHere bool) {
|
||||
if matchedHere {
|
||||
r.foldOutputGroups(groups)
|
||||
return
|
||||
}
|
||||
addAll(r.affectedGroups, groups)
|
||||
}
|
||||
|
||||
// foldOutputGroups folds only the groups that the caller reported as wholly changed
|
||||
// (outputGroups). Used for a matched object's OWN side, where a peer-seeded or
|
||||
// link-only group must not pull in its siblings.
|
||||
func (r *resolver) foldOutputGroups(groups ...[]string) {
|
||||
for _, gs := range groups {
|
||||
for _, gID := range gs {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.linkGroups) == 0 {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.linkGroups) {
|
||||
// A nameserver group has no opposite side: a peer's DNS config depends only
|
||||
// on its own membership, so a one-peer move refreshes that peer alone (folded
|
||||
// elsewhere). Fold the referenced groups only on a whole-group change.
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a linked group -> folding its groups %v (outputGroups only)", ns.ID, ns.Groups)
|
||||
r.foldOutputGroups(ns.Groups)
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.linkGroups[gID]; ok {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
|
||||
// routing peer set moved): the router's own peers refresh and so do the sources of
|
||||
// the policies reaching its network's resources. Sibling routers on the network are
|
||||
// independent and are not folded.
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q (own groups on outputGroups) + sources reaching network resources",
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
// The backing PeerGroups are the matched (own) side: fold them only on a
|
||||
// whole-group change so a one-peer move does not wake sibling backing peers. The
|
||||
// opposite side (policy sources reaching the network) is folded below.
|
||||
r.foldOutputGroups(router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -827,48 +526,42 @@ func (r *resolver) collectFromProxyServices() {
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil || !svc.Enabled {
|
||||
continue // a disabled service proxies nothing; skip existing account data
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets; access groups %v on outputGroups only",
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.affectedPeers[pid] = struct{}{}
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled {
|
||||
continue // a disabled target forwards nothing
|
||||
}
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.affectedPeers[target.TargetId] = struct{}{}
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
// AccessGroups are the matched (own) side with no opposite to fold: a member's
|
||||
// proxy access is self-contained, so a one-peer move refreshes that peer alone.
|
||||
// Fold the groups only on a whole-group change.
|
||||
r.foldOutputGroups(svc.AccessGroups)
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.linkGroups) == 0 {
|
||||
return r.changedPeers
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.linkGroups)
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeers
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
|
||||
for id := range r.changedPeers {
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
@@ -877,36 +570,54 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
return merged
|
||||
}
|
||||
|
||||
// foldRoutersForResources folds the routers serving the networks of the given
|
||||
// resources (a destination resource is reached through its network's routers). It is
|
||||
// the resource -> network -> router hop used by foldPolicySide for a destination.
|
||||
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
|
||||
}
|
||||
|
||||
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
|
||||
// the direct DestinationResource plus the resources of its destination groups.
|
||||
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// networkResourceIDs returns the IDs of all resources on the given network.
|
||||
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if resource.NetworkID == networkID {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return resourceIDs
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
@@ -916,9 +627,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -939,9 +650,6 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
@@ -1006,20 +714,44 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
|
||||
}
|
||||
}
|
||||
|
||||
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
|
||||
// rules (a disabled rule grants no access).
|
||||
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
addAll(groups, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peers[rule.SourceResource.ID] = struct{}{}
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
@@ -1044,7 +776,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
|
||||
// the resolver folds in for a changed policy, for asserting the pure logic.
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
@@ -19,14 +19,7 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
for _, rule := range p.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
@@ -87,6 +80,26 @@ func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
@@ -94,9 +107,24 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
|
||||
@@ -520,12 +520,7 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
// A membership change affects only the peer itself and the opposite side of THIS
|
||||
// group's policies — not the group's other members, and not the peer's other
|
||||
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
|
||||
// refreshes the peer without seeding its other group memberships. For an
|
||||
// intra-group policy the opposite side is the group, so its members still refresh.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
@@ -591,11 +586,10 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
|
||||
// policies refresh, not the group's other members or the peer's other groups. The
|
||||
// peer is no longer in the group's index, but LinkGroups still drives the
|
||||
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
@@ -606,6 +600,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
|
||||
@@ -217,7 +217,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
|
||||
@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -39,7 +39,7 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
@@ -102,6 +102,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -188,40 +192,27 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePeerLocation looks up the geo location for realIP, returning nil when
|
||||
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
|
||||
// what the peer already has, or the lookup failed. Geo lookups are skipped on
|
||||
// same-IP reconnects since they are comparatively expensive. The returned value
|
||||
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
|
||||
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
|
||||
return nil
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
CityName: location.City.Names.En,
|
||||
GeoNameID: location.City.GeonameID,
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -730,7 +721,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
|
||||
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
|
||||
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
@@ -989,9 +980,10 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
var peer *nbpeer.Peer
|
||||
var ipv6CapabilityChanged bool
|
||||
var metaDiff nbpeer.MetaDiff
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var err error
|
||||
var postureChecks []*posture.Checks
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -1019,16 +1011,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
|
||||
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if metaDiff.Updated() {
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -1036,11 +1037,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
@@ -1051,10 +1047,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged(), metaDiff.HostnameChanged()) {
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1063,29 +1058,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
|
||||
var reason string
|
||||
switch {
|
||||
case isStatusChanged:
|
||||
reason = "status changed"
|
||||
case updateAccountPeers:
|
||||
reason = "update account peers"
|
||||
case ipv6CapabilityChanged:
|
||||
reason = "ipv6 capability changed"
|
||||
case metaDiffAffectsPosture:
|
||||
reason = "meta diff affects posture"
|
||||
case versionChanged:
|
||||
reason = "version changed"
|
||||
case hostname:
|
||||
reason = "hostname changed"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("peer update required: %s", reason)
|
||||
return true
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
@@ -1094,8 +1066,8 @@ func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers
|
||||
// metadata change that flips a posture result removes this peer from others'
|
||||
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
|
||||
// back to the resolver.
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
|
||||
if peerNotValid || metaChangeAffectedPosture {
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
|
||||
if peerNotValid || (metaUpdated && hasPostureChecks) {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
@@ -1152,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var shouldStorePeer, shouldUpdatePeers bool
|
||||
var shouldStorePeer bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1179,10 +1151,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
shouldUpdatePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
@@ -1204,15 +1180,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.Meta = login.Meta
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
@@ -1222,7 +1190,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if shouldUpdatePeers {
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1318,22 +1286,12 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
@@ -1341,16 +1299,32 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
|
||||
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
|
||||
for _, peerID := range peerGroupIDs {
|
||||
groupIDsMap[peerID] = struct{}{}
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1362,7 +1336,11 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
|
||||
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||
}
|
||||
|
||||
@@ -1375,19 +1353,29 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
}
|
||||
|
||||
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
||||
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
|
||||
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
if slices.Contains(peerGroupIDs, sourceGroup) {
|
||||
return policy.SourcePostureChecks
|
||||
group, ok := sourceGroups[sourceGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to check peer in policy source group")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return policy.SourcePostureChecks, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
@@ -107,15 +103,6 @@ type Location struct {
|
||||
GeoNameID uint // city level geoname id
|
||||
}
|
||||
|
||||
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
|
||||
// IP.Equal, not ==.
|
||||
func (l Location) equal(other Location) bool {
|
||||
return l.CountryCode == other.CountryCode &&
|
||||
l.CityName == other.CityName &&
|
||||
l.GeoNameID == other.GeoNameID &&
|
||||
l.ConnectionIP.Equal(other.ConnectionIP)
|
||||
}
|
||||
|
||||
// NetworkAddress is the IP address with network and MAC address of a network interface
|
||||
type NetworkAddress struct {
|
||||
NetIP netip.Prefix `gorm:"serializer:json"`
|
||||
@@ -175,7 +162,49 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return len(metaDiff(p, other)) == 0
|
||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
||||
})
|
||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
||||
})
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
if !equalNetworkAddresses {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.Files, func(i, j int) bool {
|
||||
return p.Files[i].Path < p.Files[j].Path
|
||||
})
|
||||
sort.Slice(other.Files, func(i, j int) bool {
|
||||
return other.Files[i].Path < other.Files[j].Path
|
||||
})
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
if !equalFiles {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
p.GoOS == other.GoOS &&
|
||||
p.Kernel == other.Kernel &&
|
||||
p.KernelVersion == other.KernelVersion &&
|
||||
p.Core == other.Core &&
|
||||
p.Platform == other.Platform &&
|
||||
p.OS == other.OS &&
|
||||
p.OSVersion == other.OSVersion &&
|
||||
p.WtVersion == other.WtVersion &&
|
||||
p.UIVersion == other.UIVersion &&
|
||||
p.SystemSerialNumber == other.SystemSerialNumber &&
|
||||
p.SystemProductName == other.SystemProductName &&
|
||||
p.SystemManufacturer == other.SystemManufacturer &&
|
||||
p.Environment.Cloud == other.Environment.Cloud &&
|
||||
p.Environment.Platform == other.Environment.Platform &&
|
||||
p.Flags.isEqual(other.Flags) &&
|
||||
capabilitiesEqual(p.Capabilities, other.Capabilities)
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEmpty() bool {
|
||||
@@ -265,173 +294,26 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
|
||||
// new information is provided. newLocation is the geo location resolved from the
|
||||
// peer's current connection IP, or nil when there is nothing to apply (geo
|
||||
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
|
||||
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
|
||||
// diff.Updated() reports whether the peer needs to be persisted.
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
if meta.isEmpty() {
|
||||
return MetaDiff{}
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
versionChanged = p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
effectiveLocation := p.Location
|
||||
if newLocation != nil {
|
||||
effectiveLocation = *newLocation
|
||||
if p.Meta.isEqual(meta) {
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
|
||||
if diff.Updated() {
|
||||
p.Meta = meta
|
||||
}
|
||||
p.Location = effectiveLocation
|
||||
|
||||
if diff.Updated() {
|
||||
log.WithContext(ctx).Debug(diff.LogSummary())
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
|
||||
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
|
||||
// checks read it). Changed lists what moved, for logging and the persistence decision;
|
||||
// the snapshots let a posture check be replayed against old and new. Everything is derived
|
||||
// from these fields, so there are no parallel per-field flags to keep in sync.
|
||||
type MetaDiff struct {
|
||||
OldMeta PeerSystemMeta
|
||||
NewMeta PeerSystemMeta
|
||||
OldLocation Location
|
||||
NewLocation Location
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
|
||||
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
|
||||
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
|
||||
func (d *MetaDiff) Updated() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
|
||||
func (d *MetaDiff) VersionChanged() bool {
|
||||
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
|
||||
}
|
||||
|
||||
// HostnameChanged reports whether the peer's hostname changed.
|
||||
func (d *MetaDiff) HostnameChanged() bool {
|
||||
return d.OldMeta.Hostname != d.NewMeta.Hostname
|
||||
}
|
||||
|
||||
// LogSummary renders the changed fields as a single human-readable line.
|
||||
func (d *MetaDiff) LogSummary() string {
|
||||
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
|
||||
len(d.Changed), strings.Join(d.Changed, ", "))
|
||||
}
|
||||
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
|
||||
}
|
||||
|
||||
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
|
||||
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
|
||||
// list, so the log line and the persistence decision can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
|
||||
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
if !oldLocation.equal(newLocation) {
|
||||
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
// same multiplicity, ignoring order. The element type is the comparison key, so
|
||||
// every field participates in equality.
|
||||
func sameMultiset[T comparable](a, b []T) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
counts := make(map[T]int, len(a))
|
||||
for _, v := range a {
|
||||
counts[v]++
|
||||
}
|
||||
for _, v := range b {
|
||||
counts[v]--
|
||||
if counts[v] == 0 {
|
||||
delete(counts, v)
|
||||
}
|
||||
}
|
||||
return len(counts) == 0
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
// GetLastLogin returns the last login time of the peer.
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
|
||||
// map 1:1 to a single diff entry. Today the only such field is Environment, which
|
||||
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
|
||||
// entry beyond its single struct field. If you teach metaDiff to explode another
|
||||
// field into N entries, bump this by N-1; if you collapse a field, lower it.
|
||||
const metaDiffExtraEntries = 1
|
||||
|
||||
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
|
||||
// values and diffs it against the zero value, then asserts metaDiff emits exactly
|
||||
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
|
||||
//
|
||||
// The expected count is derived from the struct via reflection, so adding a field
|
||||
// to PeerSystemMeta raises the expectation automatically — but the actual diff
|
||||
// only grows if metaDiff was taught to compare the new field. A mismatch means
|
||||
// someone changed the struct without updating metaDiff (or this test's
|
||||
// extra-entry accounting), which is exactly what we want to catch.
|
||||
func TestMetaDiff_CoversAllFields(t *testing.T) {
|
||||
var full PeerSystemMeta
|
||||
exported := populateAll(t, reflect.ValueOf(&full).Elem())
|
||||
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
|
||||
|
||||
diff := metaDiff(PeerSystemMeta{}, full)
|
||||
|
||||
require.Len(t, diff, exported+metaDiffExtraEntries,
|
||||
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
|
||||
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
|
||||
"diff was: %v", diff)
|
||||
|
||||
require.False(t, full.isEqual(PeerSystemMeta{}),
|
||||
"isEqual must report a fully-populated meta as different from the zero value")
|
||||
}
|
||||
|
||||
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
|
||||
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
|
||||
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
|
||||
// compare would not change the diff count. This flips each Flags field on its own
|
||||
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
|
||||
// fails here.
|
||||
func TestFlags_isEqualChecksEveryField(t *testing.T) {
|
||||
typ := reflect.TypeOf(Flags{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
require.Equal(t, reflect.Bool, f.Type.Kind(),
|
||||
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
|
||||
|
||||
var a, b Flags
|
||||
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
|
||||
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateAll sets every exported field of the struct to a deterministic non-zero
|
||||
// value, recursing into nested structs and the element type of struct slices so
|
||||
// that each leaf differs from zero. It returns the number of exported fields on
|
||||
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
|
||||
// settable exported fields and is comparable with ==).
|
||||
func populateAll(t *testing.T, v reflect.Value) int {
|
||||
t.Helper()
|
||||
|
||||
typ := v.Type()
|
||||
exported := 0
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.PkgPath != "" { // unexported
|
||||
continue
|
||||
}
|
||||
exported++
|
||||
setNonZero(t, v.Field(i))
|
||||
}
|
||||
return exported
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
|
||||
// recursing into nested structs and populating one element of slice fields.
|
||||
func setNonZero(t *testing.T, field reflect.Value) {
|
||||
t.Helper()
|
||||
|
||||
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
|
||||
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
|
||||
return
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero")
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.SetFloat(7)
|
||||
case reflect.Struct:
|
||||
populateAll(t, field)
|
||||
case reflect.Slice:
|
||||
s := reflect.MakeSlice(field.Type(), 1, 1)
|
||||
setNonZero(t, s.Index(0))
|
||||
field.Set(s)
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
|
||||
}
|
||||
}
|
||||
@@ -49,7 +49,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -2894,141 +2893,3 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
// fakeGeo is a test double for geolocation.Geolocation. Lookup either fails
// with err (when set) or returns a record assembled from the remaining fields.
type fakeGeo struct {
	// geoNameID is reported as the city geoname id of the returned record.
	geoNameID uint
	// isoCode and cityName fill the country ISO code and English city name.
	isoCode, cityName string
	// err, when non-nil, is returned by Lookup instead of a record.
	err error
}
|
||||
|
||||
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
|
||||
if g.err != nil {
|
||||
return nil, g.err
|
||||
}
|
||||
record := &geolocation.Record{}
|
||||
record.City.GeonameID = g.geoNameID
|
||||
record.City.Names.En = g.cityName
|
||||
record.Country.ISOCode = g.isoCode
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetAllCountries is a stub: it always returns a nil slice and a nil error.
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) {
	return nil, nil
}
|
||||
|
||||
// GetCitiesByCountry is a stub: it ignores its argument and always returns a
// nil slice and a nil error.
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) {
	return nil, nil
}
|
||||
|
||||
// Stop is a stub: it does nothing and always returns nil.
func (g *fakeGeo) Stop() error {
	return nil
}
|
||||
|
||||
func TestResolvePeerLocation(t *testing.T) {
|
||||
realIP := net.ParseIP("203.0.113.10")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
geo geolocation.Geolocation
|
||||
peer *nbpeer.Peer
|
||||
realIP net.IP
|
||||
want *nbpeer.Location
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "no geo configured returns nil",
|
||||
geo: nil,
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil real IP returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "lookup error returns nil",
|
||||
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP and same geoname returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP but changed geoname returns location",
|
||||
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City B",
|
||||
GeoNameID: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IP returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("198.51.100.7"),
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no prior location returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
am := &DefaultAccountManager{geo: tt.geo}
|
||||
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
|
||||
if tt.wantNil {
|
||||
assert.Nil(t, got, "resolved location should be nil")
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got, "resolved location should not be nil")
|
||||
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
|
||||
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
|
||||
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
|
||||
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
|
||||
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
|
||||
return &nbpeer.MetaDiff{
|
||||
OldMeta: oldMeta,
|
||||
NewMeta: newMeta,
|
||||
OldLocation: oldLoc,
|
||||
NewLocation: newLoc,
|
||||
}
|
||||
}
|
||||
|
||||
func checks(def ChecksDefinition) []*Checks {
|
||||
return []*Checks{{Checks: def}}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NilDiff(t *testing.T) {
|
||||
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
})))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NBVersion(t *testing.T) {
|
||||
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
oldVer, newVer string
|
||||
want bool
|
||||
}{
|
||||
{"both above min, no flip", "1.3.0", "1.4.0", false},
|
||||
{"both below min, no flip", "1.0.0", "1.1.0", false},
|
||||
{"crosses up below->above", "1.1.0", "1.3.0", true},
|
||||
{"crosses down above->below", "1.3.0", "1.1.0", true},
|
||||
{"unparsable old only -> flip", "garbage", "1.3.0", true},
|
||||
{"unparsable both -> no flip", "garbage", "junk", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
|
||||
}})
|
||||
|
||||
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
|
||||
withinMin := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
|
||||
|
||||
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
|
||||
crossesDown := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
|
||||
// failing linux kernel flips the verdict pass -> fail.
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
|
||||
// flips the verdict.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
|
||||
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
|
||||
// A tracked process stays running while an unrelated file is added: the verdict does
|
||||
// not move, so posture is not affected.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/bar", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_GeoLocation(t *testing.T) {
|
||||
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{{CountryCode: "DE"}},
|
||||
}})
|
||||
|
||||
// Moving within allowed countries keeps the verdict; moving out flips it.
|
||||
stayAllowed := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
|
||||
|
||||
moveOut := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE"},
|
||||
nbpeer.Location{CountryCode: "FR"},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
|
||||
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
|
||||
// moving within it does not.
|
||||
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
|
||||
}})
|
||||
|
||||
movesOutOfRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
|
||||
|
||||
staysInRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
|
||||
// Hostname changes but no check reads it: not affected even with checks present.
|
||||
c := checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
|
||||
})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
|
||||
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
|
||||
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NoChecks(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
|
||||
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, nil))
|
||||
}
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -53,46 +51,6 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
|
||||
// replays each check against the peer's old and new state and compares verdicts, so a
|
||||
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
|
||||
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
|
||||
// evaluation error counts.
|
||||
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
|
||||
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
|
||||
|
||||
for _, c := range checks {
|
||||
for _, check := range c.GetChecks() {
|
||||
if verdictChanged(ctx, check, oldPeer, newPeer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verdictChanged replays check against old and new state and reports whether the verdict
|
||||
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
|
||||
// verdict (no change), an error on one side only is a flip.
|
||||
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
|
||||
oldPass, oldErr := check.Check(ctx, oldPeer)
|
||||
newPass, newErr := check.Check(ctx, newPeer)
|
||||
|
||||
oldVerdict := oldPass && (oldErr == nil)
|
||||
newVerdict := newPass && (newErr == nil)
|
||||
changed := oldVerdict != newVerdict
|
||||
|
||||
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
|
||||
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
|
||||
|
||||
return changed
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user