mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-22 15:59:59 +00:00
Compare commits
37 Commits
feat/admin
...
client_lif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7706f578fe | ||
|
|
daf5026192 | ||
|
|
ec18b07959 | ||
|
|
9628f016da | ||
|
|
b39e9df194 | ||
|
|
0388e0f262 | ||
|
|
86f896723d | ||
|
|
29ee84999c | ||
|
|
0e8fd22f36 | ||
|
|
ff98105212 | ||
|
|
6465997a69 | ||
|
|
3204270c4b | ||
|
|
6d3bcef2c4 | ||
|
|
5d7cb30e5b | ||
|
|
aff5da2c8e | ||
|
|
9b179be324 | ||
|
|
33e7b6a8f1 | ||
|
|
e0cff5e240 | ||
|
|
0085aebf77 | ||
|
|
91d2d341b7 | ||
|
|
8d46580c13 | ||
|
|
b42fe6a10f | ||
|
|
0f5d7fdc07 | ||
|
|
13c78d98f5 | ||
|
|
d1229ed84c | ||
|
|
9758145517 | ||
|
|
200a5a6a70 | ||
|
|
1f7b1ea863 | ||
|
|
4abb10c1aa | ||
|
|
a45cefe57a | ||
|
|
a6d504633f | ||
|
|
70f2097fff | ||
|
|
befa9a879c | ||
|
|
4152c41796 | ||
|
|
8b76b3d824 | ||
|
|
0503a18644 | ||
|
|
ec6512d660 |
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -59,12 +59,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
|
||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
id: test
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
|
||||
52
.github/workflows/golang-test-linux.yml
vendored
52
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -119,12 +119,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -162,7 +162,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -175,12 +175,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,12 +246,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -306,12 +306,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -363,12 +363,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -407,7 +407,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -424,12 +424,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -484,7 +484,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -529,12 +529,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -623,12 +623,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -692,12 +692,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -734,7 +734,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
4
.github/workflows/golang-test-windows.yml
vendored
4
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
10
.github/workflows/mobile-build-validation.yml
vendored
10
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
@@ -54,11 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
32
.github/workflows/release.yml
vendored
32
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
@@ -221,7 +221,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -374,7 +374,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -464,12 +464,12 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -522,7 +522,7 @@ jobs:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -534,13 +534,13 @@ jobs:
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
|
||||
12
.github/workflows/test-infrastructure-files.yml
vendored
12
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,12 +68,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/management:latest .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||
docker build -t netbirdio/relay:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -44,11 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
@@ -247,7 +247,7 @@ dockers_v2:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
@@ -295,7 +295,7 @@ dockers_v2:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
@@ -317,7 +317,7 @@ dockers_v2:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
@@ -339,7 +339,7 @@ dockers_v2:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
@@ -361,7 +361,7 @@ dockers_v2:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
@@ -383,7 +383,7 @@ dockers_v2:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
@@ -405,7 +405,7 @@ dockers_v2:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
@@ -462,13 +462,9 @@ checksum:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
@@ -151,9 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -186,9 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
return connectClient.RunOnAndroid(cfg, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
|
||||
@@ -138,23 +138,26 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
profileName := args[0]
|
||||
|
||||
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
@@ -163,6 +166,7 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
@@ -326,19 +330,3 @@ func wrapAmbiguityError(err error, handle string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||
// and returns the new profile's ID. It is the single entry point for profile
|
||||
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||
// <name>` auto-create path.
|
||||
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add profile failed: %w", err)
|
||||
}
|
||||
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
@@ -261,17 +262,46 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
return prefix + upper
|
||||
}
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
// DialClientGRPCServer returns client connection to the daemon server. It waits
|
||||
// (up to the timeout) for the daemon to become reachable so an `up` issued right
|
||||
// after `service start` tolerates the startup race. Instead of grpc's blocking
|
||||
// dial — whose raw "transport failed" retry warnings are silenced by the logger
|
||||
// config — we drive the wait ourselves and emit one clean line per failed attempt.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
return grpc.DialContext(
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
strings.TrimPrefix(addr, "tcp://"),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.Connect()
|
||||
for {
|
||||
state := conn.GetState()
|
||||
if state == connectivity.Ready {
|
||||
return conn, nil
|
||||
}
|
||||
// Log only once the connection has actually failed — not during the
|
||||
// brief Idle/Connecting phase on a healthy daemon (avoids a spurious
|
||||
// line + wait when the daemon is already up).
|
||||
if state == connectivity.TransientFailure {
|
||||
log.Infof("waiting for the netbird daemon to become available at %s...", addr)
|
||||
}
|
||||
// Wake on the next state change, but at least every second so a stuck
|
||||
// TransientFailure re-logs at a steady cadence until the timeout.
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, time.Second)
|
||||
conn.WaitForStateChange(waitCtx, state)
|
||||
waitCancel()
|
||||
if ctx.Err() != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("daemon not reachable at %s: %w", addr, ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackOff execute function in backoff cycle.
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -110,10 +111,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve the active profile's display name via the daemon, which runs
|
||||
// as root and can read the per-user profile files. The local profile
|
||||
// manager only knows the active profile ID, not its display name.
|
||||
profName := getActiveProfileName(ctx)
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||
Anonymize: anonymizeFlag,
|
||||
@@ -165,25 +167,6 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getActiveProfileName asks the daemon for the active profile's display
|
||||
// name. The daemon runs as root and can read the per-user profile files to
|
||||
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||
// error so status output degrades gracefully.
|
||||
func getActiveProfileName(ctx context.Context) string {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return resp.GetProfileName()
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "idle", "connecting", "connected":
|
||||
|
||||
@@ -128,9 +128,15 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
profileSwitched = true
|
||||
}
|
||||
|
||||
@@ -145,52 +151,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||
}
|
||||
|
||||
// switchOrCreateProfile switches the active profile to the one identified by
|
||||
// handle, creating it first when it does not exist yet. This restores the
|
||||
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||
// missing profile instead of failing.
|
||||
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok || st.Code() != codes.NotFound {
|
||||
return err
|
||||
}
|
||||
// Don't fail immediately on a create error: a concurrent run may
|
||||
// have created the profile between the NotFound above and this
|
||||
// call, in which case the retried switch still succeeds. Only
|
||||
// surface the create error if the switch also fails.
|
||||
_, createErr := createProfile(ctx, handle, username)
|
||||
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||
if createErr != nil {
|
||||
return fmt.Errorf("create profile: %w", createErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProfile dials the daemon and creates a new profile with the given
|
||||
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||
// when a daemon client is already available to reuse the connection.
|
||||
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||
// override the default profile filepath if provided
|
||||
if configPath != "" {
|
||||
@@ -241,10 +201,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
return connectClient.Run(config, nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
|
||||
@@ -264,34 +264,24 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||
client := internal.NewConnectClient(ctx, c.recorder)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// The supervisor owns the run; we wait until it is established, ends with a
|
||||
// startup error (permanent backoff err), or startCtx expires.
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
client.RunAsync(c.config, nil)
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||
// run loop to tear the engine down, so this cancel() is no longer
|
||||
// required to break the deadlock and could be removed. It is kept as a
|
||||
// defensive belt-and-suspenders: cancelling the parent context first
|
||||
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||
if err := client.WaitEstablishedOrDone(startCtx); err != nil {
|
||||
// Either startCtx expired while connecting, or the run ended before it
|
||||
// established. Cancel the client context before stopping: Engine.Start
|
||||
// blocks on the signal stream while holding the engine mutex and only
|
||||
// unblocks on cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
return fmt.Errorf("stop error after startup failure. Stop error: %w. Startup: %w", stopErr, err)
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -49,17 +49,23 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
// androidMobileDep is set on Android to inject the MobileDependency for runs
|
||||
// started through the generic entry points (Run/RunAsync, e.g. embed.Client).
|
||||
// nil on other platforms, where the dependency is empty.
|
||||
var androidMobileDep func(config *profilemanager.Config) MobileDependency
|
||||
|
||||
// mobileDependency returns the MobileDependency for a run started via the
|
||||
// generic entry points. On Android the androidMobileDep provider supplies
|
||||
// platform stubs (or real implementations); elsewhere it is empty.
|
||||
func (c *ConnectClient) mobileDependency(config *profilemanager.Config) MobileDependency {
|
||||
if androidMobileDep != nil {
|
||||
return androidMobileDep(config)
|
||||
}
|
||||
return MobileDependency{}
|
||||
}
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
runCancel context.CancelFunc
|
||||
runExited chan struct{}
|
||||
runOnce sync.Once
|
||||
runStarted atomic.Bool
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
engine *Engine
|
||||
@@ -68,41 +74,62 @@ type ConnectClient struct {
|
||||
updateManager *updater.Manager
|
||||
|
||||
persistSyncResponse bool
|
||||
|
||||
// sup serializes all start/stop requests so two lifecycle operations can
|
||||
// never overlap. See connect_lifecycle.go.
|
||||
sup *supervisor
|
||||
}
|
||||
|
||||
func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
) *ConnectClient {
|
||||
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||
// loop. runCancel is set once at construction, so Stop can call it without
|
||||
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
return &ConnectClient{
|
||||
ctx: runCtx,
|
||||
runCancel: runCancel,
|
||||
runExited: make(chan struct{}),
|
||||
config: config,
|
||||
c := &ConnectClient{
|
||||
ctx: ctx,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
c.sup = newSupervisor(ctx, c.run)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
// Run with main logic. md carries optional gRPC metadata (e.g. the UI
|
||||
// user-agent) to forward to the management/signal services; nil when none.
|
||||
func (c *ConnectClient) Run(config *profilemanager.Config, md metadata.MD, logPath string) error {
|
||||
return c.sup.start(config, md, c.mobileDependency(config), logPath)
|
||||
}
|
||||
|
||||
// RunAsync starts a client run without blocking. Used by the daemon and embed,
|
||||
// which drive the lifecycle through the supervisor rather than blocking on Run;
|
||||
// they then wait for the outcome via WaitEstablishedOrDone. The run's lifecycle
|
||||
// channels are created and owned by the supervisor — callers never hold them.
|
||||
func (c *ConnectClient) RunAsync(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.startAsync(config, md, c.mobileDependency(config), "", nil)
|
||||
}
|
||||
|
||||
// Restart atomically stops any in-flight run and starts a fresh one with the
|
||||
// given config. The stop+start happens as a single supervisor operation, so no
|
||||
// other lifecycle request can interleave between them — used for explicit
|
||||
// restarts (e.g. an MDM policy change) that must not expose a "stopped" window.
|
||||
func (c *ConnectClient) Restart(config *profilemanager.Config, md metadata.MD) {
|
||||
c.sup.restartAsync(config, md, c.mobileDependency(config), "")
|
||||
}
|
||||
|
||||
// WaitEstablishedOrDone blocks until the in-flight run becomes established (nil),
|
||||
// ends before that (the run error, or a sentinel on a clean stop), or ctx is
|
||||
// cancelled. Returns errNoRunInFlight if no run is in flight. Wraps the wait on
|
||||
// the supervisor-owned channels so callers never touch them directly.
|
||||
func (c *ConnectClient) WaitEstablishedOrDone(ctx context.Context) error {
|
||||
return c.sup.waitEstablishedOrDone(ctx)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
func (c *ConnectClient) RunOnAndroid(
|
||||
config *profilemanager.Config,
|
||||
tunAdapter device.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
@@ -121,10 +148,11 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
return c.sup.start(config, nil, mobileDependency, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
config *profilemanager.Config,
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
@@ -142,15 +170,12 @@ func (c *ConnectClient) RunOniOS(
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, logFilePath)
|
||||
return c.sup.start(config, nil, mobileDependency, logFilePath)
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||
// the loop to finish (and skip the wait if the loop never ran).
|
||||
c.runStarted.Store(true)
|
||||
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||
|
||||
// run executes a single client run. runCtx is owned by the supervisor: cancelling
|
||||
// it tears the run down (it is the parent of the per-attempt engine context).
|
||||
func (c *ConnectClient) run(runCtx context.Context, config *profilemanager.Config, mobileDependency MobileDependency, connEstablishedChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -214,18 +239,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}()
|
||||
|
||||
wrapErr := state.Wrap
|
||||
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
|
||||
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
|
||||
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if c.config.ManagementURL.Scheme == "https" {
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
|
||||
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -259,13 +284,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
if c.ctx.Err() != nil {
|
||||
if runCtx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
state.Set(StatusConnecting)
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
engineCtx, cancel := context.WithCancel(runCtx)
|
||||
defer func() {
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
@@ -273,8 +298,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
@@ -291,7 +316,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||
|
||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
defer func() {
|
||||
if err = mgmClient.Close(); err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
@@ -300,13 +325,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||
loginStarted := time.Now()
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, config)
|
||||
if err != nil {
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
c.runCancel()
|
||||
// No teardown needed: login fails before the engine is started
|
||||
// (engine.Start is below), so there is nothing running to stop.
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -360,7 +386,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -404,7 +430,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
@@ -412,12 +438,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
// The supervisor owns connEstablishedChan and it is always present. Guard
|
||||
// against a double close: operation re-runs on ErrResetConnection retries
|
||||
// within the same run, and the channel is closed only on the first connect.
|
||||
select {
|
||||
case <-connEstablishedChan:
|
||||
default:
|
||||
close(connEstablishedChan)
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
@@ -426,8 +453,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||
|
||||
// Always tear the engine down once its context is cancelled. engine.Stop
|
||||
// is nil-guarded per component, so calling it unconditionally is safe and
|
||||
// avoids both the data race on engine.wgInterface and skipping teardown
|
||||
// when the interface was never brought up (e.g. a mid-start failure).
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
@@ -445,12 +474,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// Login failed permanently: the engine was never started, so there
|
||||
// is nothing to tear down — just record that a login is needed.
|
||||
state.Set(StatusNeedsLogin)
|
||||
c.runCancel()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -471,6 +501,22 @@ func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
|
||||
return relayCfg.GetUrls(), token
|
||||
}
|
||||
|
||||
// ConnectionRunning reports whether a connection run is currently in flight
|
||||
// (connecting, connected, or reconnecting). Answered by the supervisor via a
|
||||
// serialized query, so it settles behind an in-flight stop. Distinct from
|
||||
// ServiceRunning, which reports whether the service itself is alive.
|
||||
func (c *ConnectClient) ConnectionRunning() bool {
|
||||
return c.sup.isRunning()
|
||||
}
|
||||
|
||||
// ServiceRunning reports whether the client's lifecycle supervisor is alive and
|
||||
// able to accept start/stop commands — i.e. its context has not been cancelled
|
||||
// (the daemon is not shutting down). Independent of whether a connection run is
|
||||
// up (that is ConnectionRunning).
|
||||
func (c *ConnectClient) ServiceRunning() bool {
|
||||
return c.sup.ctx.Err() == nil
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Engine() *Engine {
|
||||
if c == nil {
|
||||
return nil
|
||||
@@ -527,12 +573,10 @@ func (c *ConnectClient) Status() StatusType {
|
||||
return status
|
||||
}
|
||||
|
||||
// Stop serializes a stop request through the lifecycle supervisor and blocks
|
||||
// until the in-flight run is fully torn down.
|
||||
func (c *ConnectClient) Stop() error {
|
||||
c.runCancel()
|
||||
if c.runStarted.Load() {
|
||||
<-c.runExited
|
||||
}
|
||||
return nil
|
||||
return c.sup.stop()
|
||||
}
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
@@ -59,19 +60,17 @@ var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// Wire up the default MobileDependency provider so embed.Client.Start() works
|
||||
// on Android with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
// Applications that need P2P ICE or real DNS should replace this by setting
|
||||
// androidMobileDep before calling Start().
|
||||
androidMobileDep = func(config *profilemanager.Config) MobileDependency {
|
||||
return mobileDependencyForEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,23 +10,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
// mobileDependencyForEmbed builds the MobileDependency used by embed.Client on
|
||||
// Android so the engine's existing Android code paths work unchanged.
|
||||
func mobileDependencyForEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
) MobileDependency {
|
||||
return MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
|
||||
362
client/internal/connect_lifecycle.go
Normal file
362
client/internal/connect_lifecycle.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
// errAlreadyRunning is returned when a start is requested while a run is already
|
||||
// in flight.
|
||||
var errAlreadyRunning = errors.New("client is already running")
|
||||
|
||||
// errNoRunInFlight is returned by waitEstablishedOrDone when no run is active.
|
||||
var errNoRunInFlight = errors.New("no connection run in flight")
|
||||
|
||||
// errStoppedBeforeEstablished is returned when a run ended (cleanly) before the
|
||||
// connection was established.
|
||||
var errStoppedBeforeEstablished = errors.New("run stopped before the connection was established")
|
||||
|
||||
// lifecycleOp is a serialized lifecycle operation processed by the supervisor.
|
||||
type lifecycleOp int
|
||||
|
||||
const (
|
||||
opStart lifecycleOp = iota
|
||||
opStop
|
||||
opRestart
|
||||
opStatus
|
||||
opWaitEstablished
|
||||
)
|
||||
|
||||
// lifecycleCmd is a single lifecycle request handed to the supervisor goroutine.
|
||||
// They all flow through the same cmdCh so they are strictly ordered (FIFO) with
|
||||
// respect to each other.
|
||||
type lifecycleCmd struct {
|
||||
op lifecycleOp
|
||||
config *profilemanager.Config
|
||||
md metadata.MD
|
||||
mobileDep MobileDependency
|
||||
logPath string
|
||||
|
||||
// done is the caller's notification channel (nil for fire-and-forget). Its
|
||||
// meaning depends on op:
|
||||
// - opStart: receives the run's end result when the run terminates, or
|
||||
// errAlreadyRunning immediately if a run is already in flight.
|
||||
// - opStop: receives nil once the in-flight run has fully unwound.
|
||||
// - opWaitEstablished: receives the wait outcome (see waitEstablishedOrDone).
|
||||
done chan error
|
||||
|
||||
reply chan bool // opStatus only: receives whether a run is in flight
|
||||
waitCtx context.Context // opWaitEstablished only: the waiter's cancellation context
|
||||
}
|
||||
|
||||
// runState holds the lifecycle channels of a single in-flight run, owned by the
|
||||
// loop goroutine. It never escapes the supervisor as an API; the only readers
|
||||
// are the per-wait goroutines the loop spawns for opWaitEstablished.
|
||||
//
|
||||
// connEstablishedChan is closed by the run once the connection is established.
|
||||
// The supervisor creates and owns it — callers no longer supply it; they observe
|
||||
// it through waitEstablishedOrDone. ended is closed (broadcast) when the run
|
||||
// terminates, so any number of waiters can observe it; err is the run's end
|
||||
// result, valid only after ended is closed.
|
||||
type runState struct {
|
||||
connEstablishedChan chan struct{} // closed by the run on established
|
||||
ended chan struct{} // closed by finishRun when the run terminates
|
||||
err error // run end result, valid after ended is closed
|
||||
}
|
||||
|
||||
// runEndResult is sent by the run goroutine to the supervisor when a run ends,
|
||||
// whether on its own (error / external context cancellation) or because of a Stop.
|
||||
type runEndResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// runFunc executes a single client run bound to the supervisor-owned context,
|
||||
// with the config supplied by the start request.
|
||||
type runFunc func(ctx context.Context, config *profilemanager.Config, mobileDep MobileDependency, connEstablishedChan chan struct{}, logPath string) error
|
||||
|
||||
// supervisor serializes start/stop of a single client run. Every request goes
|
||||
// through cmdCh and is handled one at a time by the loop goroutine, so two
|
||||
// lifecycle operations can never overlap and their order is preserved (FIFO).
|
||||
// The loop goroutine is the sole owner of curStart/runCancel, so that state
|
||||
// needs no locking. The loop exits when the parent context is cancelled.
|
||||
type supervisor struct {
|
||||
ctx context.Context
|
||||
run runFunc
|
||||
cmdCh chan lifecycleCmd
|
||||
runEnded chan runEndResult
|
||||
|
||||
// owned exclusively by the loop goroutine. curStart is the in-flight start
|
||||
// command (nil = idle); its done channel is notified when the run ends.
|
||||
// curRun holds that run's lifecycle channels; runCancel cancels it.
|
||||
curStart *lifecycleCmd
|
||||
curRun *runState
|
||||
runCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newSupervisor(ctx context.Context, run runFunc) *supervisor {
|
||||
s := &supervisor{
|
||||
ctx: ctx,
|
||||
run: run,
|
||||
cmdCh: make(chan lifecycleCmd, 16),
|
||||
runEnded: make(chan runEndResult, 1),
|
||||
}
|
||||
go s.loop()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *supervisor) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
s.shutdown()
|
||||
return
|
||||
case cmd := <-s.cmdCh:
|
||||
switch cmd.op {
|
||||
case opStart:
|
||||
s.handleStart(cmd)
|
||||
case opStop:
|
||||
s.handleStop(cmd)
|
||||
case opRestart:
|
||||
s.handleRestart(cmd)
|
||||
case opStatus:
|
||||
cmd.reply <- (s.isRunningInternal())
|
||||
case opWaitEstablished:
|
||||
s.handleWaitEstablished(cmd)
|
||||
}
|
||||
case res := <-s.runEnded:
|
||||
// Run ended on its own, without an explicit Stop.
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
notify(cmd.done, errAlreadyRunning)
|
||||
return
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(s.ctx)
|
||||
if cmd.md != nil {
|
||||
// Carry caller-supplied gRPC metadata (e.g. UI user-agent) into the run
|
||||
// context so the engine's management/signal calls forward it. The cancel
|
||||
// still drives runCtx (metadata wrapping preserves cancellation).
|
||||
runCtx = metadata.NewOutgoingContext(runCtx, cmd.md)
|
||||
}
|
||||
s.runCancel = cancel
|
||||
s.curStart = &cmd
|
||||
s.curRun = &runState{connEstablishedChan: make(chan struct{}), ended: make(chan struct{})}
|
||||
|
||||
go func(ctx context.Context, cfg *profilemanager.Config, m MobileDependency, established chan struct{}, lp string) {
|
||||
err := s.run(ctx, cfg, m, established, lp)
|
||||
s.runEnded <- runEndResult{err: err}
|
||||
}(runCtx, cmd.config, cmd.mobileDep, s.curRun.connEstablishedChan, cmd.logPath)
|
||||
}
|
||||
|
||||
func (s *supervisor) handleStop(cmd lifecycleCmd) {
|
||||
if !s.isRunningInternal() {
|
||||
notify(cmd.done, nil)
|
||||
return
|
||||
}
|
||||
s.stopCurrentRun()
|
||||
notify(cmd.done, nil)
|
||||
}
|
||||
|
||||
// handleRestart tears down any in-flight run and starts a fresh one in a single
|
||||
// loop turn. No other command can interleave between the stop and the start
|
||||
// (the loop is single-threaded), so the swap is atomic without relying on any
|
||||
// daemon-side lock — that is what an explicit restart (e.g. MDM config change)
|
||||
// needs to avoid a window where the client is observably stopped.
|
||||
func (s *supervisor) handleRestart(cmd lifecycleCmd) {
|
||||
if s.isRunningInternal() {
|
||||
s.stopCurrentRun()
|
||||
}
|
||||
s.handleStart(cmd)
|
||||
}
|
||||
|
||||
// stopCurrentRun cancels the in-flight run and blocks the supervisor until it
|
||||
// has fully unwound, so the next action starts from a clean slate. The run
|
||||
// goroutine reports completion via runEnded. Caller must hold an in-flight run
|
||||
// (curStart != nil).
|
||||
func (s *supervisor) stopCurrentRun() {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
|
||||
// finishRun resets lifecycle state after a run terminates and hands the run
|
||||
// error back to whoever asked to be notified of the start.
|
||||
func (s *supervisor) finishRun(err error) {
|
||||
s.runCancel = nil
|
||||
if s.isRunningInternal() {
|
||||
// Publish the result to the broadcast channel before nil-ing curRun, so
|
||||
// any opWaitEstablished goroutines blocked on ended observe err.
|
||||
s.curRun.err = err
|
||||
close(s.curRun.ended)
|
||||
s.curRun = nil
|
||||
|
||||
notify(s.curStart.done, err)
|
||||
s.curStart = nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleWaitEstablished answers an opWaitEstablished request. The select itself
|
||||
// runs in a spawned goroutine on the run's channels so it never blocks the loop;
|
||||
// the loop only snapshots the in-flight run's channels (which it owns) here.
|
||||
func (s *supervisor) handleWaitEstablished(cmd lifecycleCmd) {
|
||||
caller := cmd.done
|
||||
if !s.isRunningInternal() {
|
||||
notify(caller, errNoRunInFlight)
|
||||
return
|
||||
}
|
||||
rs := s.curRun
|
||||
established := rs.connEstablishedChan
|
||||
ctx := cmd.waitCtx
|
||||
go func() {
|
||||
select {
|
||||
case <-established:
|
||||
notify(caller, nil)
|
||||
case <-rs.ended:
|
||||
if rs.err != nil {
|
||||
notify(caller, rs.err)
|
||||
return
|
||||
}
|
||||
notify(caller, errStoppedBeforeEstablished)
|
||||
case <-ctx.Done():
|
||||
notify(caller, ctx.Err())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdown tears down the in-flight run when the parent context is cancelled,
|
||||
// then fails any still-queued commands so their callers never hang.
|
||||
func (s *supervisor) shutdown() {
|
||||
if s.runCancel != nil {
|
||||
s.runCancel()
|
||||
res := <-s.runEnded
|
||||
s.finishRun(res.err)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case cmd := <-s.cmdCh:
|
||||
notify(cmd.done, s.ctx.Err())
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startAsync enqueues a start without blocking. If done is non-nil it receives
|
||||
// the run's end result (or errAlreadyRunning on rejection, or the context error
|
||||
// on shutdown).
|
||||
func (s *supervisor) startAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string, done chan error) {
|
||||
cmd := lifecycleCmd{op: opStart, config: config, md: md, mobileDep: mobileDep, logPath: logPath, done: done}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
notify(done, s.ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// restartAsync enqueues an atomic stop+start without blocking. The supervisor
|
||||
// tears down any in-flight run and starts a fresh one with the supplied config
|
||||
// in a single loop turn (see handleRestart). Fire-and-forget: the new run owns
|
||||
// its lifecycle channels, observed via waitEstablishedOrDone.
|
||||
func (s *supervisor) restartAsync(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) {
|
||||
cmd := lifecycleCmd{op: opRestart, config: config, md: md, mobileDep: mobileDep, logPath: logPath}
|
||||
select {
|
||||
case s.cmdCh <- cmd:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// start enqueues a start and blocks until the run terminates, preserving the
|
||||
// blocking contract of the legacy Run entry points.
|
||||
func (s *supervisor) start(config *profilemanager.Config, md metadata.MD, mobileDep MobileDependency, logPath string) error {
|
||||
done := make(chan error, 1)
|
||||
s.startAsync(config, md, mobileDep, logPath, done)
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// isRunning asks the loop whether a run is in flight. The query is serialized
|
||||
// with start/stop, so during a stop it waits for the teardown to settle and
|
||||
// then reports the final state — never a transient "half-stopped".
|
||||
func (s *supervisor) isRunning() bool {
|
||||
reply := make(chan bool, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStatus, reply: reply}:
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
return r
|
||||
case <-s.ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *supervisor) isRunningInternal() bool {
|
||||
return s.curStart != nil
|
||||
}
|
||||
|
||||
// waitEstablishedOrDone blocks until the in-flight run becomes established
|
||||
// (returns nil) or ends before that (returns the run error, or
|
||||
// errStoppedBeforeEstablished on a clean stop), or ctx is cancelled. Returns
|
||||
// errNoRunInFlight if no run is in flight. The wait is performed by a goroutine
|
||||
// spawned inside the loop (see handleWaitEstablished); the run's channels never
|
||||
// leave the supervisor.
|
||||
func (s *supervisor) waitEstablishedOrDone(ctx context.Context) error {
|
||||
reply := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opWaitEstablished, waitCtx: ctx, done: reply}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-reply:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// stop enqueues a stop and blocks until the in-flight run is fully torn down.
|
||||
func (s *supervisor) stop() error {
|
||||
done := make(chan error, 1)
|
||||
select {
|
||||
case s.cmdCh <- lifecycleCmd{op: opStop, done: done}:
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends on a caller-supplied channel without blocking. The channel is
|
||||
// expected to be buffered (cap 1); a nil channel means the caller did not ask
|
||||
// to be notified.
|
||||
func notify(ch chan error, err error) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -207,35 +207,3 @@ func FormatAnswers(answers []dns.RR) string {
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||
// advertise EDNS0.
|
||||
func StripOPT(msg *dns.Msg) {
|
||||
if len(msg.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := msg.Extra[:0]
|
||||
for _, rr := range msg.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
msg.Extra = out
|
||||
}
|
||||
|
||||
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||
// the message, if present.
|
||||
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||
opt := msg.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||
return ede, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -120,42 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
StripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestExtractEDE(t *testing.T) {
|
||||
t.Run("no edns", func(t *testing.T) {
|
||||
_, ok := ExtractEDE(&dns.Msg{})
|
||||
assert.False(t, ok, "message without OPT has no EDE")
|
||||
})
|
||||
|
||||
t.Run("edns without ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
rm.SetEdns0(4096, false)
|
||||
_, ok := ExtractEDE(rm)
|
||||
assert.False(t, ok, "OPT without EDE option returns false")
|
||||
})
|
||||
|
||||
t.Run("with ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||
rm.Extra = append(rm.Extra, opt)
|
||||
|
||||
ede, ok := ExtractEDE(rm)
|
||||
assert.True(t, ok, "EDE option should be found")
|
||||
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -39,15 +38,11 @@ const (
|
||||
// defaultWarningDelayBase is the starting grace window before a
|
||||
// "Nameserver group unreachable" event fires for a group that's
|
||||
// never been healthy and only has overlay upstreams with no
|
||||
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||
// see warningDelay.
|
||||
defaultWarningDelayBase = 60 * time.Second
|
||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
||||
defaultWarningDelayBase = 30 * time.Second
|
||||
// warningDelayBonusCap caps the route-count bonus added to the
|
||||
// base grace window. See warningDelay.
|
||||
// base grace window. See warningDelayFor.
|
||||
warningDelayBonusCap = 30 * time.Second
|
||||
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||
)
|
||||
|
||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||
@@ -140,7 +135,7 @@ type DefaultServer struct {
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxHandlers []handlerWrapper
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@@ -204,6 +199,8 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||
type DefaultServerConfig struct {
|
||||
WgInterface WGIface
|
||||
@@ -292,6 +289,7 @@ func newDefaultServer(
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -300,7 +298,7 @@ func newDefaultServer(
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
warningDelayBase: warningDelayBaseFromEnv(),
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
@@ -330,7 +328,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
type routeSettable interface {
|
||||
setSelectedRoutes(func() route.HAMap)
|
||||
}
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
if h, ok := entry.handler.(routeSettable); ok {
|
||||
h.setSelectedRoutes(selected)
|
||||
}
|
||||
@@ -980,23 +978,19 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxHandlers {
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
// The local resolver is a persistent singleton shared by every custom
|
||||
// zone and reused across config updates. Its chain registrations are
|
||||
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||
// context (breaking external CNAME-target resolution) and clears its
|
||||
// records, so it must not be torn down here.
|
||||
if existing.handler != s.localResolver {
|
||||
existing.handler.Stop()
|
||||
}
|
||||
existing.handler.Stop()
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
s.dnsMuxHandlers = muxUpdates
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||
@@ -1160,26 +1154,6 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
||||
return false
|
||||
}
|
||||
|
||||
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||
// non-positive values fall back to defaultWarningDelayBase.
|
||||
func warningDelayBaseFromEnv() time.Duration {
|
||||
val := os.Getenv(envWarningDelay)
|
||||
if val == "" {
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
d, err := time.ParseDuration(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
if d <= 0 {
|
||||
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||
return defaultWarningDelayBase
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// warningDelay returns the grace window for the given selected-route
|
||||
// count. Scales gently: +1s per 100 routes, capped by
|
||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||
@@ -1230,7 +1204,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
||||
// in more than one handler.
|
||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||
for _, entry := range s.dnsMuxHandlers {
|
||||
for _, entry := range s.dnsMuxMap {
|
||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||
if !ok {
|
||||
continue
|
||||
|
||||
@@ -104,6 +104,19 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -119,20 +132,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -154,17 +169,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
@@ -173,10 +191,10 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
@@ -197,13 +215,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
@@ -212,7 +232,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
@@ -220,7 +240,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -242,7 +262,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -264,7 +284,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
@@ -281,41 +301,42 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
@@ -372,7 +393,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
@@ -384,20 +405,14 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,8 +512,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
@@ -1014,15 +1029,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := []handlerWrapper{
|
||||
{
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
@@ -1031,15 +1046,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := []handlerWrapper{
|
||||
{
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
"upstream-root2": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
@@ -1048,22 +1063,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := []handlerWrapper{
|
||||
{
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
},
|
||||
{
|
||||
"upstream-other": {
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
@@ -1074,7 +1089,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers []handlerWrapper
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
@@ -1358,38 +1373,32 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxHandlers: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
var found *handlerWrapper
|
||||
for i := range server.dnsMuxHandlers {
|
||||
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||
found = &server.dnsMuxHandlers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||
if found != nil {
|
||||
assert.Equal(t, expectedDomain, found.domain,
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for _, entry := range server.dnsMuxHandlers {
|
||||
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@@ -1404,7 +1413,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxHandlers {
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
@@ -1413,108 +1422,12 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxHandlers")
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||
// for the given fqdn pattern at the given priority.
|
||||
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||
for _, h := range s.handlerChain.handlers {
|
||||
if h.OrigPattern == pattern && h.Priority == priority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||
// tracks each (handler, domain) registration independently when one handler
|
||||
// serves multiple zones. Every custom zone is served by the same handler
|
||||
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||
// zones onto one entry, leaving removed zones in the chain to answer
|
||||
// authoritatively with no records.
|
||||
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||
// One handler serves every custom zone, mirroring s.localResolver.
|
||||
shared := &mockHandler{Id: "local-resolver"}
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Two custom zones under the same handler. The surviving zone is registered
|
||||
// last, mirroring the management emission order.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test should be registered after the first update")
|
||||
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should be registered after the first update")
|
||||
|
||||
// Remove one zone, keep the other.
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||
"peerzone.test should remain after removing userzone.test")
|
||||
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||
}
|
||||
|
||||
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||
// does not tear down the shared local resolver during reconfiguration. The
|
||||
// resolver is a process-lifetime singleton reused across config updates;
|
||||
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||
// resolution) and clears its records. updateMux must deregister its chain
|
||||
// entries without stopping it. Records surviving a teardown update is the
|
||||
// observable proxy: Stop() would have cleared them.
|
||||
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||
resolver := local.NewResolver()
|
||||
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||
Name: "peer.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}))
|
||||
|
||||
server := &DefaultServer{
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
localResolver: resolver,
|
||||
}
|
||||
|
||||
server.updateMux([]handlerWrapper{
|
||||
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||
})
|
||||
|
||||
// Remove the zone. The resolver must survive so its records and lookup
|
||||
// context stay intact for the next registration.
|
||||
server.updateMux(nil)
|
||||
|
||||
var response *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
response = m
|
||||
return nil
|
||||
},
|
||||
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||
|
||||
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2136,6 +2049,7 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
@@ -2293,7 +2207,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
@@ -2369,11 +2283,12 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
@@ -2480,6 +2395,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
@@ -2491,7 +2407,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2528,6 +2444,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
@@ -2542,7 +2459,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2567,32 +2484,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set bool
|
||||
val string
|
||||
want time.Duration
|
||||
}{
|
||||
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envWarningDelay, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(envWarningDelay)
|
||||
}
|
||||
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
@@ -2704,6 +2595,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
@@ -2721,7 +2613,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
@@ -2748,6 +2640,7 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -443,32 +443,29 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||
u.markUpstreamOk(upstream)
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||
// refused zones, transient recursion errors), not reachability
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(rm)
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(rm)
|
||||
stripOPT(rm)
|
||||
}
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
@@ -523,6 +520,22 @@ func upstreamUDPSize() uint16 {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
|
||||
@@ -517,78 +517,6 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||
// those are per-question outcomes from a reachable server and must not mark
|
||||
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respA mockUpstreamResponse
|
||||
respB mockUpstreamResponse
|
||||
wantHealthy bool
|
||||
}{
|
||||
{
|
||||
name: "both SERVFAIL are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "both REFUSED are reachable",
|
||||
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||
wantHealthy: true,
|
||||
},
|
||||
{
|
||||
name: "timeout marks unhealthy",
|
||||
respA: mockUpstreamResponse{err: timeoutErr},
|
||||
respB: mockUpstreamResponse{err: timeoutErr},
|
||||
wantHealthy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): tc.respA,
|
||||
b.String(): tc.respB,
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a, b})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||
if tc.wantHealthy {
|
||||
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||
} else {
|
||||
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -985,6 +913,19 @@ func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
@@ -26,15 +26,6 @@ import (
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||
// client can see the reason without inspecting this peer's logs. They live in
|
||||
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||
const (
|
||||
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||
edeNetbirdUpstreamFailure uint16 = 49153
|
||||
)
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
@@ -229,7 +220,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -342,7 +333,6 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
@@ -384,10 +374,6 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
@@ -428,33 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return edeNetbirdUpstreamTimeout
|
||||
}
|
||||
return edeNetbirdUpstreamFailure
|
||||
}
|
||||
|
||||
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||
// It deliberately omits the upstream server address, which may be an internal
|
||||
// resolver and is exposed to any client permitted to use the route; the full
|
||||
// detail stays in the forwarder's local log.
|
||||
func edeText(dnsErr *net.DNSError) string {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return "netbird forwarder: upstream timeout"
|
||||
}
|
||||
return "netbird forwarder: upstream failure"
|
||||
}
|
||||
|
||||
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||
// creating the OPT pseudo-record if the response does not already carry one.
|
||||
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||
opt := resp.IsEdns0()
|
||||
if opt == nil {
|
||||
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||
opt = resp.IsEdns0()
|
||||
}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -618,85 +617,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lookupErr error
|
||||
reqEdns bool
|
||||
wantEDE bool
|
||||
wantCode uint16
|
||||
wantTextHas string
|
||||
}{
|
||||
{
|
||||
name: "timeout with edns0",
|
||||
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamTimeout,
|
||||
wantTextHas: "netbird forwarder: upstream timeout",
|
||||
},
|
||||
{
|
||||
name: "server failure with edns0",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamFailure,
|
||||
wantTextHas: "netbird forwarder: upstream failure",
|
||||
},
|
||||
{
|
||||
name: "no edns0 in request omits ede",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: false,
|
||||
wantEDE: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
if tt.reqEdns {
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
mockResolver.AssertExpectations(t)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||
|
||||
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||
if !tt.wantEDE {
|
||||
assert.False(t, ok, "response must not carry EDE")
|
||||
return
|
||||
}
|
||||
require.True(t, ok, "response must carry EDE")
|
||||
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -86,8 +88,6 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -201,8 +201,6 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
started bool
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
@@ -283,15 +281,9 @@ func NewEngine(
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
// The engine is single-use: a fresh instance is built per connection
|
||||
// cycle (see Client.run), so the run context is created once here rather
|
||||
// than in Start.
|
||||
ctx, cancel := context.WithCancel(clientCtx)
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
@@ -324,34 +316,8 @@ func (e *Engine) Stop() error {
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.cancel()
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
e.stopLocked()
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked tears down everything Start may have brought up, in the order
|
||||
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||
func (e *Engine) stopLocked() {
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
@@ -402,6 +368,10 @@ func (e *Engine) stopLocked() {
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
@@ -420,6 +390,21 @@ func (e *Engine) stopLocked() {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
@@ -457,38 +442,18 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// The engine is single-use. Reject a duplicate start and a start on an
|
||||
// already-stopped engine (run context cancelled).
|
||||
if e.started {
|
||||
return ErrEngineAlreadyStarted
|
||||
}
|
||||
|
||||
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||
}
|
||||
|
||||
e.started = true
|
||||
|
||||
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||
// run context first so goroutines started before the failure (connMgr,
|
||||
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||
// not just what close() covers.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
e.cancel()
|
||||
e.stopLocked()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
@@ -522,11 +487,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -561,6 +528,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -569,6 +537,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -580,6 +549,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -604,7 +574,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
if err := e.dnsServer.Initialize(); err != nil {
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -616,9 +588,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
if err = e.receiveSignalEvents(); err != nil {
|
||||
return err
|
||||
}
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
@@ -670,6 +640,7 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
@@ -1156,6 +1127,20 @@ func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
}
|
||||
|
||||
// wrapDisconnectError classifies a receive-loop failure before the run is torn
|
||||
// down. An auth rejection (PermissionDenied/Unauthenticated) means the session
|
||||
// needs re-login and retrying is futile, so mark it terminal (NeedsLogin) — run()
|
||||
// then exits on its own instead of spinning the backoff. Any other failure is a
|
||||
// recoverable connection reset that the backoff should retry.
|
||||
func (e *Engine) wrapDisconnectError(err error) {
|
||||
state := CtxGetState(e.ctx)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied || s.Code() == codes.Unauthenticated) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
return
|
||||
}
|
||||
_ = state.Wrap(ErrResetConnection)
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
@@ -1182,9 +1167,9 @@ func (e *Engine) receiveJobEvents() {
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1266,9 +1251,9 @@ func (e *Engine) receiveManagementEvents() {
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if management is unavailable for a long time, or rejects
|
||||
// us (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
@@ -1729,7 +1714,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() error {
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
@@ -1792,20 +1777,15 @@ func (e *Engine) receiveSignalEvents() error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// happens if signal is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
// happens if signal is unavailable for a long time, or rejects us
|
||||
// (auth). wrapDisconnectError decides retry vs needs-login.
|
||||
e.wrapDisconnectError(err)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||
e.signal.WaitStreamConnected(e.ctx)
|
||||
if err := e.ctx.Err(); err != nil {
|
||||
return fmt.Errorf("wait for signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
e.signal.WaitStreamConnected()
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
|
||||
@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
|
||||
@@ -251,14 +251,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||
// the original client did not request EDNS0.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
if !hadEdns {
|
||||
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
@@ -268,13 +260,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||
}
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(reply)
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
|
||||
@@ -171,13 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
// Persist the latest sync response so DebugBundle can include the network
|
||||
// map. On iOS this is backed by disk to keep it out of the constrained
|
||||
// process memory (see the syncstore package).
|
||||
connectClient.SetSyncResponsePersistence(true)
|
||||
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
return connectClient.RunOniOS(cfg, fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -36,7 +36,6 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
@@ -52,19 +51,8 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||
// process (decoupled from the network extension), so without this the server
|
||||
// lingers after the user dismisses the browser and the next connect stalls
|
||||
// trying to bind the same port.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
ctx: context.Background(),
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
@@ -72,24 +60,12 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||
// safe to call when no login is running.
|
||||
func (a *Auth) Stop() {
|
||||
if a.cancel != nil {
|
||||
a.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
|
||||
@@ -344,9 +344,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
||||
|
||||
@@ -5,7 +5,6 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/pprof"
|
||||
|
||||
@@ -28,11 +27,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}
|
||||
|
||||
var clientMetrics debug.MetricsExporter
|
||||
if s.connectClient != nil {
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
}
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
if cm := engine.GetClientMetrics(); cm != nil {
|
||||
clientMetrics = cm
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +45,10 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
defer s.cleanupBundleCapture()
|
||||
|
||||
var refreshStatus func()
|
||||
if s.connectClient != nil {
|
||||
engine := s.connectClient.Engine()
|
||||
if engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
if engine := s.connectClient.Engine(); engine != nil {
|
||||
refreshStatus = func() {
|
||||
log.Debug("refreshing system health status for debug bundle")
|
||||
engine.RunHealthProbes(true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +112,7 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
}
|
||||
s.connectClient.SetLogLevel(level)
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
@@ -134,20 +126,13 @@ func (s *Server) SetSyncResponsePersistence(_ context.Context, req *proto.SetSyn
|
||||
|
||||
enabled := req.GetEnabled()
|
||||
s.persistSyncResponse = enabled
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
}
|
||||
s.connectClient.SetSyncResponsePersistence(enabled)
|
||||
|
||||
return &proto.SetSyncResponsePersistenceResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
cClient := s.connectClient
|
||||
if cClient == nil {
|
||||
return nil, errors.New("connect client is not initialized")
|
||||
}
|
||||
|
||||
return cClient.GetLatestSyncResponse()
|
||||
return s.connectClient.GetLatestSyncResponse()
|
||||
}
|
||||
|
||||
// StartCPUProfile starts CPU profiling in the daemon.
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -39,12 +38,11 @@ type conflictCheck struct {
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// 1. Stop the in-flight run via the supervisor (blocks until fully torn down).
|
||||
// 2. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// 3. Start a fresh run with the new config.
|
||||
// 4. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
@@ -52,39 +50,24 @@ type conflictCheck struct {
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
// Hold s.mutex for the entire restart sequence (stop + re-start). Any
|
||||
// concurrent Up/Down/Status arriving while MDM is restarting blocks on the
|
||||
// Lock until we are done — they then observe the post-restart state coherently.
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
// No run in flight, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel daemon-side login/status activities tied to the old run; the run
|
||||
// itself is torn down atomically by the supervisor inside Restart (see
|
||||
// restartEngineForMDMLocked), which stops and re-starts in one operation.
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
@@ -131,14 +114,13 @@ func (s *Server) publishConfigChangedEvent(source string) {
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
// (re-running applyMDMPolicy via Config.apply) and starts a fresh run.
|
||||
// Mirrors the tail of Server.Start so a runtime MDM change behaves
|
||||
// identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
// for the entire restart sequence so concurrent Up/Down/Status RPCs
|
||||
// observe a coherent post-restart state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
@@ -154,13 +136,13 @@ func (s *Server) restartEngineForMDMLocked() error {
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
log.Info("MDM restart: atomically restarting the run with re-resolved config")
|
||||
// MDM restart has no incoming RPC metadata; fire and forget. Restart is a
|
||||
// single supervisor op (atomic stop+start), so there is no observable
|
||||
// "stopped" window between tearing down the old run and starting the new.
|
||||
s.connectClient.Restart(config, nil)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -34,10 +34,6 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -147,10 +143,6 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
@@ -199,10 +191,6 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
|
||||
@@ -8,12 +8,10 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -39,15 +37,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
probeThreshold = time.Second * 5
|
||||
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
|
||||
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
|
||||
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
|
||||
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
|
||||
defaultInitialRetryTime = 30 * time.Minute
|
||||
defaultMaxRetryInterval = 60 * time.Minute
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
probeThreshold = time.Second * 5
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
@@ -72,15 +62,8 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
// Run state (in-flight? established/done channels?) is owned entirely by the
|
||||
// supervisor inside connectClient — the daemon keeps no per-run fields.
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -136,6 +119,13 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
// The ConnectClient is daemon-lifetime: build it exactly once, here. Its
|
||||
// supervisor lives as long as the daemon; Up/Down/MDM and reconnects all
|
||||
// drive this same instance. updateManager isn't ready yet (created in
|
||||
// Start) and is injected there via SetUpdateManager.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
s.startSleepDetector()
|
||||
@@ -147,7 +137,7 @@ func (s *Server) Start() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.clientRunning {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,6 +155,7 @@ func (s *Server) Start() error {
|
||||
stateMgr := statemanager.New(s.profileManager.GetStatePath())
|
||||
s.updateManager = updater.NewManager(s.statusRecorder, stateMgr)
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
s.connectClient.SetUpdateManager(s.updateManager)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
@@ -190,7 +181,9 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
// actCancel cancels in-flight foreground operations (login/status); the run
|
||||
// itself is owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
// copy old default config
|
||||
@@ -232,99 +225,14 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
// Boot autoconnect: no incoming RPC metadata. The supervisor runs the
|
||||
// client and reconnects internally; we just fire and forget (the run owns
|
||||
// its established/done channels).
|
||||
s.connectClient.RunAsync(config, nil)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
|
||||
log.Debugf("run client connection exited with error: %v", err)
|
||||
}
|
||||
log.Tracef("client connection exited")
|
||||
return
|
||||
}
|
||||
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
go func() {
|
||||
t := time.NewTicker(24 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
mgmtState := statusRecorder.GetManagementState()
|
||||
signalState := statusRecorder.GetSignalState()
|
||||
if mgmtState.Connected && signalState.Connected {
|
||||
log.Tracef("resetting status")
|
||||
backOff.Reset()
|
||||
} else {
|
||||
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
runOperation := func() error {
|
||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
||||
if err != nil {
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("client connection exited gracefully, do not need to retry")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||
@@ -720,13 +628,22 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
|
||||
// The client (and its supervisor) is built once in New(), so a nil here
|
||||
// never happens in production — Up is only reachable after New() has run and
|
||||
// the gRPC server is serving. The real case this guards is the daemon
|
||||
// SHUTTING DOWN: rootCtx is cancelled, the supervisor is no longer accepting
|
||||
// commands, so ServiceRunning() is false even though the client exists. Bail
|
||||
// loud instead of enqueuing a run that will never start. (nil only happens in
|
||||
// tests that build a Server without New(); ServiceRunning is nil-safe.)
|
||||
if !s.connectClient.ServiceRunning() {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("service is not running, start the netbird service for 'up' to take effect")
|
||||
}
|
||||
|
||||
// If a connection run is already in flight, the existing engine is on the
|
||||
// job — just wait for it. Otherwise fall through to start a fresh run.
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -764,14 +681,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
md, ok := metadata.FromIncomingContext(callerCtx)
|
||||
if ok {
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
// actCancel cancels in-flight foreground ops (login/status); the run is
|
||||
// owned by the supervisor and stopped via Stop, not this cancel.
|
||||
_, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
// Forward the caller's gRPC metadata (e.g. UI user-agent) into the run.
|
||||
md, _ := metadata.FromIncomingContext(callerCtx)
|
||||
|
||||
if s.config == nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("config is not defined, please call login command first")
|
||||
@@ -812,35 +729,26 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.connectClient.RunAsync(s.config, md)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
// todo: handle potential race conditions
|
||||
// waitForUp blocks until the in-flight run becomes established (success) or ends
|
||||
// before that (failure). The wait is owned by the supervisor (via the client) —
|
||||
// the daemon holds no per-run state here.
|
||||
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return nil, fmt.Errorf("client gave up to connect")
|
||||
case <-s.clientRunningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
if err := s.connectClient.WaitEstablishedOrDone(timeoutCtx); err != nil {
|
||||
log.Debugf("waiting for the connection to be established failed: %v", err)
|
||||
return nil, fmt.Errorf("connection not established: %w", err)
|
||||
}
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
@@ -935,11 +843,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
// Down engine work in the daemon.
|
||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
giveUpChan := s.clientGiveUpChan
|
||||
|
||||
// cleanupConnection stops the run through the supervisor, which blocks until
|
||||
// the run has fully unwound — no separate goroutine-quiescence wait needed.
|
||||
if err := s.cleanupConnection(); err != nil {
|
||||
s.mutex.Unlock()
|
||||
// todo review to update the status in case any type of error
|
||||
log.Errorf("failed to shut down properly: %v", err)
|
||||
return nil, err
|
||||
@@ -948,20 +856,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
|
||||
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
|
||||
// The giveUpChan is closed at the end of connectWithRetryRuns.
|
||||
if giveUpChan != nil {
|
||||
select {
|
||||
case <-giveUpChan:
|
||||
log.Debugf("client goroutine finished successfully")
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -972,38 +866,19 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
// to skip the engine shutdown entirely.
|
||||
var engine *internal.Engine
|
||||
if s.connectClient != nil {
|
||||
engine = s.connectClient.Engine()
|
||||
// Tear the client down through the lifecycle supervisor BEFORE cancelling
|
||||
// the retry context. Stop serializes on the supervisor queue and blocks
|
||||
// until the in-flight run has fully unwound (a clean, synchronous teardown).
|
||||
// It must run before actCancel: cancelling the context first would make
|
||||
// Stop observe a dead context and return early without waiting.
|
||||
if err := s.connectClient.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop the retry goroutine so it does not start a fresh run. The client
|
||||
// itself is daemon-lifetime and intentionally kept (a later Up reuses it).
|
||||
s.actCancel()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||
// making the run loop the sole owner of engine shutdown.
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.connectClient = nil
|
||||
s.isSessionActive.Store(false)
|
||||
|
||||
log.Infof("service is down")
|
||||
@@ -1138,7 +1013,7 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
|
||||
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient.ConnectionRunning() {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
@@ -1184,48 +1059,13 @@ func (s *Server) Status(
|
||||
ctx context.Context,
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-s.clientRunningChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// A run that hits a terminal auth failure now exits on its own (engine marks
|
||||
// NeedsLogin), so we no longer poll-and-cancel: we wait for the in-flight run
|
||||
// to become established or to end. With no run in flight this returns
|
||||
// immediately (errNoRunInFlight); either way we then report the status below.
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady {
|
||||
if err := s.connectClient.WaitEstablishedOrDone(ctx); err != nil && ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1263,10 +1103,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
@@ -1304,10 +1140,6 @@ func (s *Server) GetPeerSSHHostKey(
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
@@ -1474,17 +1306,13 @@ func (s *Server) WaitJWTToken(
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy.
|
||||
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
|
||||
s.mutex.Lock()
|
||||
if !s.clientRunning {
|
||||
if !s.connectClient.ConnectionRunning() {
|
||||
s.mutex.Unlock()
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
|
||||
}
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
@@ -1538,10 +1366,6 @@ func isUnixRunningDesktop() bool {
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
@@ -1820,22 +1644,6 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
|
||||
log.Tracef("running client connection")
|
||||
client := internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
client.SetUpdateManager(s.updateManager)
|
||||
client.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
s.mutex.Lock()
|
||||
s.connectClient = client
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := client.Run(runningChan, s.logFile); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
@@ -1897,45 +1705,6 @@ func (s *Server) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
|
||||
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
|
||||
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
|
||||
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
|
||||
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
|
||||
multiplier := defaultRetryMultiplier
|
||||
|
||||
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
|
||||
// parse the multiplier from the environment variable string value to float64
|
||||
value, err := strconv.ParseFloat(envValue, 64)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
|
||||
} else {
|
||||
multiplier = value
|
||||
}
|
||||
}
|
||||
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: initialInterval,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: multiplier,
|
||||
MaxInterval: maxInterval,
|
||||
MaxElapsedTime: maxElapsedTime, // 14 days
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// parseEnvDuration parses the environment variable and returns the duration
|
||||
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
|
||||
if envValue := os.Getenv(envVar); envValue != "" {
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
return duration
|
||||
}
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
|
||||
}
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -15,14 +15,19 @@ import (
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
return &Server{
|
||||
rootCtx: context.Background(),
|
||||
ctx := context.Background()
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
// Honor the production invariant: the daemon-lifetime client always exists
|
||||
// (built in New). Server methods rely on s.connectClient being non-nil.
|
||||
s.connectClient = internal.NewConnectClient(ctx, s.statusRecorder)
|
||||
return s
|
||||
}
|
||||
|
||||
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
|
||||
return internal.NewConnectClient(ctx, nil, nil)
|
||||
return internal.NewConnectClient(ctx, nil)
|
||||
}
|
||||
|
||||
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
|
||||
@@ -87,41 +92,36 @@ func TestConcurrentConnectClientAccess(t *testing.T) {
|
||||
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
|
||||
}
|
||||
|
||||
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
|
||||
// properly nils out connectClient.
|
||||
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
// TestCleanupConnection_KeepsClientStopsRunning validates that cleanupConnection
|
||||
// clears the daemon "up" intent but KEEPS the daemon-lifetime ConnectClient
|
||||
// (it is reused across Up/Down; only the run is stopped).
|
||||
func TestCleanupConnection_KeepsClientStopsRunning(t *testing.T) {
|
||||
s := newTestServer()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
s.clientRunning = true
|
||||
|
||||
err := s.cleanupConnection()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
assert.NotNil(t, s.connectClient, "connectClient is daemon-lifetime and must persist after cleanup")
|
||||
assert.False(t, s.connectClient.ConnectionRunning(), "no run should be in flight after cleanup")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestCleanState_NilConnectClient(t *testing.T) {
|
||||
// TestCleanState_NotConnected validates that CleanState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestCleanState_NotConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil // will cause error if it tries to proceed past the nil check
|
||||
s.profileManager = nil // will cause error if it tries to proceed
|
||||
|
||||
// Should not panic — the nil check should prevent calling Status() on nil
|
||||
assert.NotPanics(t, func() {
|
||||
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
|
||||
})
|
||||
}
|
||||
|
||||
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
|
||||
// when connectClient is nil.
|
||||
func TestDeleteState_NilConnectClient(t *testing.T) {
|
||||
// TestDeleteState_NotConnected validates that DeleteState doesn't panic when no
|
||||
// connection run is in flight.
|
||||
func TestDeleteState_NotConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.connectClient = nil
|
||||
s.profileManager = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
@@ -129,60 +129,6 @@ func TestDeleteState_NilConnectClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestDownThenUp_StaleRunningChan documents the known state issue where
|
||||
// clientRunningChan from a previous connection is already closed, causing
|
||||
// waitForUp() to return immediately on reconnect.
|
||||
func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Simulate state after a successful connection
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
close(s.clientRunningChan) // closed when engine started
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
s.connectClient = newDummyConnectClient(context.Background())
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
waitDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.waitForUp(ctx)
|
||||
waitDone <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-waitDone:
|
||||
assert.NoError(t, err, "waitForUp returns success on stale channel")
|
||||
// But connectClient is still nil — this is the stale state issue
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
|
||||
s.mutex.Unlock()
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
|
||||
// ConnectClient has nil Engine (before Run is called).
|
||||
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -61,65 +60,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -38,7 +37,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
|
||||
|
||||
// CleanState handles cleaning of states (performing cleanup operations)
|
||||
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
@@ -81,7 +80,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
|
||||
|
||||
// DeleteState handles deletion of states without cleanup
|
||||
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
|
||||
if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
|
||||
if s.connectClient.ConnectionRunning() {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
|
||||
@@ -62,10 +62,6 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
}
|
||||
|
||||
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, nil, fmt.Errorf("engine not initialized")
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newAdminCommands creates the admin command tree with combined-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
|
||||
mgmtConfig, err := cfg.ToManagementConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create management config: %w", err)
|
||||
}
|
||||
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, cfg)
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newAdminCommands())
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
|
||||
63
combined/cmd/token.go
Normal file
63
combined/cmd/token.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
56
docker/build-env/README.md
Normal file
56
docker/build-env/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Build environments
|
||||
|
||||
Dockerfiles that pin the same toolchain CI uses, so a developer can
|
||||
reproduce a CI build locally without installing platform SDKs on their
|
||||
workstation. The version pins in each `Dockerfile` must stay in lockstep
|
||||
with `.github/workflows/`.
|
||||
|
||||
## `android/`
|
||||
|
||||
Mirrors `.github/workflows/mobile-build-validation.yml` (`android_build`
|
||||
job). Carries Go 1.25.5, Adopt JDK 11, Android cmdline-tools 8512546,
|
||||
NDK 23.1.7779620 and gomobile pinned at the CI commit. Use it to
|
||||
produce `netbird.aar` from `./client/android`:
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-android docker/build-env/android
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
gomobile bind \
|
||||
-o netbird.aar \
|
||||
-javapkg=io.netbird.gomobile \
|
||||
-ldflags="-checklinkname=0 \
|
||||
-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
-X github.com/netbirdio/netbird/version.version=local" \
|
||||
./client/android
|
||||
```
|
||||
|
||||
To build the full Android APK, bind-mount the `android-client` repo as
|
||||
well and run its own `./gradlew assembleDebug` from inside the
|
||||
container (the gradle wrapper ships with `android-client`).
|
||||
|
||||
## `windows-cross/`
|
||||
|
||||
Cross-compiles Windows binaries from Linux using `mingw-w64`. Lets you
|
||||
verify that `GOOS=windows go build ./...` compiles cleanly without
|
||||
needing a Windows VM. Cannot run Windows tests — the `golang-test-windows`
|
||||
CI job executes on a native `windows-latest` runner with wintun.dll
|
||||
and PsExec, neither of which lives under Linux containers.
|
||||
|
||||
```bash
|
||||
docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
```
|
||||
|
||||
## What is NOT here
|
||||
|
||||
- **iOS / macOS**: cannot legally run macOS in Docker (Apple EULA),
|
||||
and Xcode is not redistributable. The `ios_build` CI job uses a
|
||||
`macos-latest` GitHub runner; locally you need a real Mac.
|
||||
|
||||
- **Native Windows tests**: see note above. The Linux+mingw image
|
||||
builds, it does not execute Windows-host code paths
|
||||
(registry, wintun, services, PsExec workflows).
|
||||
|
||||
When CI version pins change, update the corresponding `ARG` lines in
|
||||
the Dockerfiles and the README's table of versions.
|
||||
86
docker/build-env/android/Dockerfile
Normal file
86
docker/build-env/android/Dockerfile
Normal file
@@ -0,0 +1,86 @@
|
||||
# Android build environment.
|
||||
#
|
||||
# Mirrors the toolchain pinned by .github/workflows/mobile-build-validation.yml
|
||||
# so a `gomobile bind` against ./client/android in this image produces the
|
||||
# same netbird.aar that CI builds.
|
||||
#
|
||||
# Tooling versions (must stay in sync with the CI workflow):
|
||||
# - Ubuntu 22.04 (matches the ubuntu-latest GitHub runner)
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - Adopt JDK 11 (matches actions/setup-java@v3 java-version: 11, distribution: adopt)
|
||||
# - Android SDK cmdline-tools 8512546
|
||||
# - Android NDK 23.1.7779620
|
||||
# - gomobile commit v0.0.0-20251113184115-a159579294ab
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-android docker/build-env/android
|
||||
#
|
||||
# # bind the netbird checkout in and run the same gomobile command CI runs
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-android \
|
||||
# gomobile bind \
|
||||
# -o netbird.aar \
|
||||
# -javapkg=io.netbird.gomobile \
|
||||
# -ldflags="-checklinkname=0 \
|
||||
# -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard \
|
||||
# -X github.com/netbirdio/netbird/version.version=local" \
|
||||
# ./client/android
|
||||
#
|
||||
# To build the full APK, mount the android-client repo too and run
|
||||
# `./gradlew assembleDebug` from /android-client (this image carries
|
||||
# gradle's prerequisites JDK + Android SDK but not the gradle wrapper —
|
||||
# that ships with android-client).
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Versions — bump in lockstep with .github/workflows/mobile-build-validation.yml.
|
||||
ARG GO_VERSION=1.25.5
|
||||
ARG ANDROID_CMDLINE_TOOLS_VERSION=8512546
|
||||
ARG ANDROID_NDK_VERSION=23.1.7779620
|
||||
ARG GOMOBILE_VERSION=v0.0.0-20251113184115-a159579294ab
|
||||
|
||||
ENV ANDROID_HOME=/opt/android-sdk
|
||||
ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${ANDROID_NDK_VERSION}
|
||||
ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV CGO_ENABLED=0
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${ANDROID_HOME}/cmdline-tools/latest/bin:${ANDROID_HOME}/platform-tools:${JAVA_HOME}/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
unzip \
|
||||
git \
|
||||
openjdk-11-jdk-headless \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod). actions/setup-go fetches the same tarball.
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
# Install Android SDK command-line tools, accept licenses, install NDK.
|
||||
RUN mkdir -p "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& curl -fsSL -o /tmp/cmdline.zip \
|
||||
"https://dl.google.com/android/repository/commandlinetools-linux-${ANDROID_CMDLINE_TOOLS_VERSION}_latest.zip" \
|
||||
&& unzip -q /tmp/cmdline.zip -d "${ANDROID_HOME}/cmdline-tools" \
|
||||
&& mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" \
|
||||
&& rm /tmp/cmdline.zip \
|
||||
&& yes | sdkmanager --licenses > /dev/null \
|
||||
&& sdkmanager --install "ndk;${ANDROID_NDK_VERSION}" "platform-tools" > /dev/null
|
||||
|
||||
# Install gomobile at the same commit CI pins. Don't run `gomobile init` here:
|
||||
# `init` resolves the NDK at runtime, do it on the first bind in the mounted
|
||||
# workspace so the cache lands on the host volume.
|
||||
RUN GOBIN=/usr/local/bin go install "golang.org/x/mobile/cmd/gomobile@${GOMOBILE_VERSION}" \
|
||||
&& gomobile version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Default entrypoint is a plain shell so the image is composable: callers pass
|
||||
# the full gomobile / gradle command they want to run.
|
||||
CMD ["/bin/bash"]
|
||||
63
docker/build-env/windows-cross/Dockerfile
Normal file
63
docker/build-env/windows-cross/Dockerfile
Normal file
@@ -0,0 +1,63 @@
|
||||
# Windows-cross build environment.
|
||||
#
|
||||
# Cross-compiles Windows .exe targets from a Linux container using
|
||||
# mingw-w64. Mirrors the toolchain set used by
|
||||
# .github/workflows/golang-test-windows.yml insofar as that is possible
|
||||
# without a Windows kernel.
|
||||
#
|
||||
# IMPORTANT — what this image CAN do:
|
||||
# - `GOOS=windows go build ./...` to validate that Windows builds compile
|
||||
# - CGO Windows cross-compile via x86_64-w64-mingw32-gcc when CGO_ENABLED=1
|
||||
# (matches CI's choco-installed mingw-w64)
|
||||
#
|
||||
# IMPORTANT — what this image CANNOT do:
|
||||
# - Run Windows binaries (no Windows kernel under Docker on Linux).
|
||||
# - Replicate the CI's `go test` runs which execute on a real
|
||||
# windows-latest runner (wintun.dll, PsExec, registry, etc.).
|
||||
# Use the CI for that or a native Windows VM.
|
||||
#
|
||||
# Usage (from the netbird repo root):
|
||||
#
|
||||
# docker build -t netbird/build-windows docker/build-env/windows-cross
|
||||
#
|
||||
# # Cross-compile a static client (.exe) from Linux:
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'CGO_ENABLED=1 GOOS=windows GOARCH=amd64 \
|
||||
# CC=x86_64-w64-mingw32-gcc CXX=x86_64-w64-mingw32-g++ \
|
||||
# go build -o netbird.exe ./client'
|
||||
#
|
||||
# # Just validate that everything *compiles* on Windows (no CGO):
|
||||
# docker run --rm -v "$PWD:/src" -w /src netbird/build-windows \
|
||||
# bash -c 'GOOS=windows GOARCH=amd64 go build ./...'
|
||||
#
|
||||
# Tooling versions (keep in sync with go.mod and any future explicit pin
|
||||
# documented in golang-test-windows.yml):
|
||||
# - Ubuntu 22.04
|
||||
# - Go 1.25.5 (matches go.mod)
|
||||
# - mingw-w64 (Ubuntu package — pin further if drift becomes a problem)
|
||||
|
||||
FROM ubuntu:22.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG GO_VERSION=1.25.5
|
||||
|
||||
ENV GOPATH=/go
|
||||
ENV GOTOOLCHAIN=local
|
||||
ENV PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
git \
|
||||
build-essential \
|
||||
mingw-w64 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Go (matches go.mod).
|
||||
RUN curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" \
|
||||
| tar -C /usr/local -xz \
|
||||
&& go version
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,616 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird Enterprise — Getting Started
|
||||
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||
# embedded identity provider. Owner is created via first-login flow.
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_secret() {
|
||||
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
|
||||
}
|
||||
|
||||
rand_b64_key() {
|
||||
openssl rand -base64 32
|
||||
}
|
||||
|
||||
check_nb_domain() {
|
||||
local domain="$1"
|
||||
if [[ -z "$domain" ]]; then
|
||||
echo "The domain cannot be empty." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" == "netbird.example.com" ]]; then
|
||||
echo "The domain cannot be netbird.example.com" > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
|
||||
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
|
||||
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_domain_resolves() {
|
||||
local domain="$1"
|
||||
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
|
||||
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
|
||||
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
|
||||
return 1
|
||||
}
|
||||
|
||||
read_nb_domain() {
|
||||
local value=""
|
||||
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if ! check_nb_domain "$value"; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
if ! check_domain_resolves "$value"; then
|
||||
echo "" > /dev/stderr
|
||||
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
|
||||
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
|
||||
local confirm=""
|
||||
echo -n "Continue anyway? [y/N]: " > /dev/stderr
|
||||
read -r confirm < /dev/tty
|
||||
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
|
||||
read_nb_domain
|
||||
return
|
||||
fi
|
||||
fi
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
# read_yes_no "<prompt>" [<default y|n>]
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
local counter=1
|
||||
while true; do
|
||||
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
|
||||
break
|
||||
fi
|
||||
if [[ $counter -eq 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres is taking too long. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
done
|
||||
echo " done"
|
||||
set -e
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
check_openssl
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
|
||||
echo "Generated files already exist in $(pwd)."
|
||||
echo "If you want to reinitialize the environment, please remove them first:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
|
||||
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
|
||||
echo "Be aware this will remove all data from the database."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
echo " Enables traffic events logging on the management server."
|
||||
echo " When enabled, the NetBird stack also runs NATS along with two"
|
||||
echo " additional containers: netbird-receiver (the traffic log receiver"
|
||||
echo " service) and netbird-enricher (the traffic log enricher service)."
|
||||
echo " It still has to be turned on from the dashboard settings afterwards."
|
||||
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
|
||||
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
|
||||
|
||||
echo ""
|
||||
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||
|
||||
echo ""
|
||||
|
||||
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
|
||||
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
|
||||
|
||||
POSTGRES_USER="netbird"
|
||||
POSTGRES_DB="netbird"
|
||||
POSTGRES_PASSWORD=$(rand_secret)
|
||||
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
|
||||
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
|
||||
|
||||
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
|
||||
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
echo ""
|
||||
echo "Selected:"
|
||||
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
|
||||
echo " Domain: ${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Rendering files into $(pwd) ..."
|
||||
install -m 600 /dev/null .env
|
||||
render_env >> .env
|
||||
render_docker_compose > docker-compose.yml
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
|
||||
fi
|
||||
render_caddyfile > Caddyfile
|
||||
install -m 600 /dev/null config.yaml
|
||||
render_config_yaml >> config.yaml
|
||||
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
echo ""
|
||||
echo "Starting postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
sleep 2
|
||||
wait_postgres
|
||||
|
||||
echo ""
|
||||
echo "Starting remaining services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Done."
|
||||
echo ""
|
||||
echo "Dashboard: https://${NETBIRD_DOMAIN}"
|
||||
echo ""
|
||||
echo "Open the dashboard in a browser to complete the first-login owner setup."
|
||||
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
|
||||
echo ""
|
||||
echo "Tail logs:"
|
||||
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
render_env() {
|
||||
cat <<EOF
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
# Domain
|
||||
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
|
||||
|
||||
# Image tags. Default to "latest"
|
||||
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
|
||||
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
|
||||
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# License keys
|
||||
EOF
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
fi
|
||||
cat <<EOF
|
||||
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
EOF
|
||||
|
||||
cat <<EOF
|
||||
|
||||
# Postgres
|
||||
POSTGRES_USER=${POSTGRES_USER}
|
||||
POSTGRES_DB=${POSTGRES_DB}
|
||||
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
|
||||
|
||||
# Relay
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
|
||||
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
# Datastore encryption
|
||||
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
|
||||
# Dashboard OIDC scopes
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_docker_compose() {
|
||||
render_compose_header
|
||||
render_compose_common
|
||||
render_compose_server
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
render_compose_flow
|
||||
fi
|
||||
render_compose_postgres
|
||||
render_compose_footer
|
||||
}
|
||||
|
||||
render_compose_header() {
|
||||
cat <<'EOF'
|
||||
x-default: &default
|
||||
restart: unless-stopped
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: '500m'
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_common() {
|
||||
cat <<'EOF'
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
container_name: netbird-caddy
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||
container_name: netbird-dashboard
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||
- AUTH_AUDIENCE=netbird-dashboard
|
||||
- AUTH_CLIENT_ID=netbird-dashboard
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||
- NGINX_SSL_PORT=443
|
||||
- LETSENCRYPT_DOMAIN=
|
||||
- LETSENCRYPT_EMAIL=
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_server() {
|
||||
cat <<'EOF'
|
||||
netbird-server:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
|
||||
container_name: netbird-server
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
dashboard:
|
||||
condition: service_started
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
ports:
|
||||
- '3478:3478/udp'
|
||||
volumes:
|
||||
- netbird_data:/var/lib/netbird
|
||||
- ./config.yaml:/etc/netbird/config.yaml
|
||||
command: ["--config", "/etc/netbird/config.yaml"]
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_flow() {
|
||||
cat <<'EOF'
|
||||
nats:
|
||||
<<: *default
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
|
||||
enricher:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
|
||||
container_name: netbird-enricher
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
volumes:
|
||||
- netbird_enricher:/var/lib/netbird
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_DATADIR=/var/lib/netbird
|
||||
- NB_MANAGEMENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
|
||||
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_METRICS_PORT=9091
|
||||
- NB_PERSISTENCE_RETENTION_PERIOD=168h
|
||||
|
||||
receiver:
|
||||
<<: *default
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
|
||||
container_name: netbird-receiver
|
||||
networks: [netbird]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
- NB_FLOW_LISTEN_PORT=80
|
||||
- NB_FLOW_ADAPTER_TYPE=nats
|
||||
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||
- NB_FLOW_NATS_STREAM=traffic-events
|
||||
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_postgres() {
|
||||
cat <<'EOF'
|
||||
postgres:
|
||||
<<: *default
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- POSTGRES_DB=${POSTGRES_DB}
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
render_compose_footer() {
|
||||
cat <<'EOF'
|
||||
volumes:
|
||||
netbird_data:
|
||||
EOF
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
netbird_nats_data:
|
||||
netbird_enricher:
|
||||
EOF
|
||||
fi
|
||||
cat <<'EOF'
|
||||
netbird_postgres:
|
||||
netbird_caddy_data:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
EOF
|
||||
}
|
||||
|
||||
render_caddyfile() {
|
||||
cat <<'EOF'
|
||||
{
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
X-Content-Type-Options "nosniff"
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
X-XSS-Protection "1; mode=block"
|
||||
-Server
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:80 {
|
||||
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
|
||||
}
|
||||
|
||||
{$CADDY_SECURE_DOMAIN}:443 {
|
||||
import security_headers
|
||||
# Signal (gRPC over h2c)
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
|
||||
# Management (gRPC over h2c + HTTP)
|
||||
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
|
||||
reverse_proxy /api/* netbird-server:80
|
||||
reverse_proxy /ws-proxy/* netbird-server:80
|
||||
# Embedded IdP (OAuth2 endpoints served by netbird server)
|
||||
reverse_proxy /oauth2/* netbird-server:80
|
||||
# Relay (WebSocket multiplexed on the same port)
|
||||
reverse_proxy /relay* netbird-server:80
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<'EOF'
|
||||
# Flow receiver (gRPC over h2c)
|
||||
reverse_proxy /flow.FlowService/* h2c://receiver:80
|
||||
EOF
|
||||
fi
|
||||
|
||||
cat <<'EOF'
|
||||
# Dashboard
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
EOF
|
||||
}
|
||||
|
||||
render_config_yaml() {
|
||||
cat <<EOF
|
||||
# NetBird Enterprise server configuration.
|
||||
# Generated by getting-started-enterprise.sh. Mode 600.
|
||||
|
||||
server:
|
||||
listenAddress: ":80"
|
||||
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
|
||||
|
||||
metricsPort: 9090
|
||||
healthcheckAddress: ":9000"
|
||||
|
||||
logLevel: "info"
|
||||
logFile: "console"
|
||||
|
||||
# TLS is terminated by Caddy in front; leave this block empty.
|
||||
tls:
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
|
||||
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
|
||||
dataDir: "/var/lib/netbird/"
|
||||
|
||||
disableAnonymousMetrics: false
|
||||
disableGeoliteUpdate: false
|
||||
|
||||
auth:
|
||||
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
|
||||
localAuthDisabled: false
|
||||
signKeyRefreshEnabled: false
|
||||
dashboardRedirectURIs:
|
||||
- "https://${NETBIRD_DOMAIN}/nb-auth"
|
||||
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
|
||||
cliRedirectURIs:
|
||||
- "http://localhost:53000/"
|
||||
|
||||
store:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
|
||||
|
||||
activityStore:
|
||||
engine: "postgres"
|
||||
dsn: "${POSTGRES_DSN}"
|
||||
EOF
|
||||
|
||||
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
trafficFlow:
|
||||
enabled: true
|
||||
address: "https://${NETBIRD_DOMAIN}:443"
|
||||
interval: "60s"
|
||||
EOF
|
||||
fi
|
||||
}
|
||||
|
||||
init_environment
|
||||
@@ -1,638 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# NetBird — community combined → Enterprise combined migration
|
||||
#
|
||||
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||
# existing files. Original docker-compose.yml and config.yaml are never
|
||||
# modified.
|
||||
#
|
||||
# Steps (all optional, asked interactively):
|
||||
# 1. Image swap — replace community images with enterprise cloud images.
|
||||
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||
#
|
||||
# To revert:
|
||||
# docker compose down
|
||||
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||
# # at the end of this script's run.
|
||||
# docker compose up -d
|
||||
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
return
|
||||
fi
|
||||
if docker compose --help &> /dev/null; then
|
||||
echo "docker compose"
|
||||
return
|
||||
fi
|
||||
echo "docker-compose is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_yq() {
|
||||
if ! command -v yq &> /dev/null; then
|
||||
cat > /dev/stderr <<'EOF'
|
||||
yq is required to parse and update YAML safely.
|
||||
|
||||
macOS: brew install yq
|
||||
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||
|
||||
EOF
|
||||
exit 1
|
||||
fi
|
||||
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_openssl() {
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
rand_password() {
|
||||
openssl rand -hex 32
|
||||
}
|
||||
|
||||
read_required() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -r value < /dev/tty
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_secret() {
|
||||
local prompt="$1"
|
||||
local value=""
|
||||
while [[ -z "$value" ]]; do
|
||||
echo -n "$prompt: " > /dev/stderr
|
||||
read -rs value < /dev/tty
|
||||
echo "" > /dev/stderr
|
||||
if [[ -z "$value" ]]; then
|
||||
echo "Value cannot be empty." > /dev/stderr
|
||||
fi
|
||||
done
|
||||
echo "$value"
|
||||
}
|
||||
|
||||
read_yes_no() {
|
||||
local prompt="$1"
|
||||
local default="${2:-n}"
|
||||
local hint
|
||||
if [[ "$default" == "y" ]]; then
|
||||
hint="[Y/n]"
|
||||
else
|
||||
hint="[y/N]"
|
||||
fi
|
||||
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||
local ans=""
|
||||
read -r ans < /dev/tty
|
||||
if [[ -z "$ans" ]]; then
|
||||
ans="$default"
|
||||
fi
|
||||
case "$ans" in
|
||||
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||
*) echo "no" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
detect_combined_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_dashboard_service() {
|
||||
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||
}
|
||||
|
||||
detect_config_yaml_host_path() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_data_volume() {
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||
}
|
||||
|
||||
detect_exposed_address() {
|
||||
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
|
||||
}
|
||||
|
||||
detect_compose_network() {
|
||||
local tag
|
||||
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
|
||||
case "$tag" in
|
||||
"!!seq")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
"!!map")
|
||||
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
|
||||
;;
|
||||
*)
|
||||
echo "default"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Renderers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Build docker-compose.override.yml from the steps the operator selected.
|
||||
# Service names match what we detected on the operator's side.
|
||||
render_override() {
|
||||
cat <<EOF
|
||||
# Generated by migrate-to-enterprise.sh. Mode 644.
|
||||
# Merged with docker-compose.yml automatically by Docker Compose.
|
||||
# Remove this file (and config.yaml.enterprise if present) to revert.
|
||||
|
||||
services:
|
||||
${DASHBOARD_SERVICE}:
|
||||
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
|
||||
|
||||
${COMBINED_SERVICE}:
|
||||
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
EOF
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
|
||||
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
|
||||
|
||||
postgres:
|
||||
image: postgres:17
|
||||
container_name: netbird-postgres
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
environment:
|
||||
POSTGRES_USER: netbird
|
||||
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
|
||||
POSTGRES_DB: netbird
|
||||
volumes:
|
||||
- netbird_postgres:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 20
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
nats:
|
||||
image: nats:2
|
||||
container_name: netbird-nats
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||
volumes:
|
||||
- netbird_nats_data:/data
|
||||
|
||||
flow-enricher:
|
||||
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
|
||||
container_name: netbird-flow-enricher
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_DATADIR: /var/lib/netbird
|
||||
NB_MANAGEMENT_STORE_ENGINE: postgres
|
||||
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
|
||||
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_METRICS_PORT: 9091
|
||||
NB_PERSISTENCE_RETENTION_PERIOD: 168h
|
||||
|
||||
flow-receiver:
|
||||
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
|
||||
container_name: netbird-flow-receiver
|
||||
restart: unless-stopped
|
||||
networks: [${COMPOSE_NETWORK}]
|
||||
depends_on:
|
||||
nats:
|
||||
condition: service_started
|
||||
environment:
|
||||
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||
NB_FLOW_LISTEN_PORT: 80
|
||||
NB_FLOW_ADAPTER_TYPE: nats
|
||||
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||
NB_FLOW_NATS_STREAM: traffic-events
|
||||
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
|
||||
- traefik.http.routers.netbird-flow.entrypoints=websecure
|
||||
- traefik.http.routers.netbird-flow.tls=true
|
||||
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
|
||||
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
|
||||
- traefik.http.routers.netbird-flow.priority=100
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
|
||||
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Volume declarations for anything new the override introduced
|
||||
local has_volumes="no"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
has_volumes="yes"
|
||||
fi
|
||||
|
||||
if [[ "$has_volumes" == "yes" ]]; then
|
||||
cat <<EOF
|
||||
|
||||
volumes:
|
||||
EOF
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo " netbird_postgres:"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo " netbird_nats_data:"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Build config.yaml.enterprise by yq-editing the operator's existing
|
||||
# config.yaml. We don't touch the original file.
|
||||
render_enterprise_config() {
|
||||
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||
|
||||
yq eval "
|
||||
.server.store.engine = \"postgres\" |
|
||||
.server.store.dsn = \"$pg_dsn\" |
|
||||
.server.activityStore.engine = \"postgres\" |
|
||||
.server.activityStore.dsn = \"$pg_dsn\" |
|
||||
.server.authStore.engine = \"postgres\" |
|
||||
.server.authStore.dsn = \"$pg_dsn\"
|
||||
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
|
||||
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
local flow_addr="${NETBIRD_DOMAIN}"
|
||||
yq eval -i "
|
||||
.server.trafficFlow.enabled = true |
|
||||
.server.trafficFlow.address = \"$flow_addr\" |
|
||||
.server.trafficFlow.interval = \"60s\"
|
||||
" "$ENTERPRISE_CONFIG_FILE"
|
||||
fi
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
resolve_data_volume() {
|
||||
local short="$1"
|
||||
local actual
|
||||
# Resolve project-prefixed volume name from Docker Compose config first.
|
||||
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
|
||||
if [[ -n "$actual" && "$actual" != "null" ]]; then
|
||||
echo "$actual"
|
||||
return
|
||||
fi
|
||||
# Relative bind mount: docker-compose resolves it against the compose
|
||||
# file's directory, but `docker run -v` resolves it against the current
|
||||
# working directory. Normalize to an absolute path so both interpretations
|
||||
# agree (and the printed revert command works from any CWD).
|
||||
if [[ "$short" == ./* || "$short" == ../* ]]; then
|
||||
local compose_dir
|
||||
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
|
||||
(
|
||||
cd "$compose_dir"
|
||||
cd "$(dirname "$short")"
|
||||
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
|
||||
)
|
||||
return
|
||||
fi
|
||||
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
|
||||
echo "$short"
|
||||
}
|
||||
|
||||
backup_sqlite() {
|
||||
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
local data_volume_actual
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
|
||||
docker run --rm \
|
||||
-v "${data_volume_actual}:/var/lib/netbird:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
busybox \
|
||||
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
|
||||
local copied
|
||||
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
|
||||
if [[ -z "$copied" ]]; then
|
||||
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo " done"
|
||||
}
|
||||
|
||||
run_migrate_store() {
|
||||
echo "Running migrate-store (SQLite → Postgres) ..."
|
||||
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
|
||||
echo " done"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration() {
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_yq
|
||||
check_openssl
|
||||
|
||||
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||
|
||||
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
|
||||
echo "Migration artifacts already exist in $(pwd):"
|
||||
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
|
||||
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo ""
|
||||
echo "Either you've already migrated, or a previous run was interrupted."
|
||||
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
COMBINED_SERVICE=$(detect_combined_service)
|
||||
DASHBOARD_SERVICE=$(detect_dashboard_service)
|
||||
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
|
||||
DATA_VOLUME=$(detect_data_volume)
|
||||
COMPOSE_NETWORK=$(detect_compose_network)
|
||||
|
||||
if [[ -z "$COMBINED_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
|
||||
echo "This script targets the community combined-server deployment." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DASHBOARD_SERVICE" ]]; then
|
||||
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
|
||||
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$DATA_VOLUME" ]]; then
|
||||
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Detected existing deployment:"
|
||||
echo " Combined service: $COMBINED_SERVICE"
|
||||
echo " Dashboard: $DASHBOARD_SERVICE"
|
||||
echo " config.yaml: $CONFIG_YAML_HOST"
|
||||
echo " Data volume: $DATA_VOLUME"
|
||||
echo " Network: $COMPOSE_NETWORK"
|
||||
echo ""
|
||||
|
||||
local proceed
|
||||
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||
if [[ "$proceed" != "yes" ]]; then
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Step 1 — always (this is the point of the script)
|
||||
MIGRATE_IMAGES="yes"
|
||||
echo ""
|
||||
echo "Step 1: Image swap (community → Enterprise). License key required."
|
||||
NB_LICENSE_KEY=$(read_secret " License key")
|
||||
GHCR_USERNAME="netbirdExtAccess1"
|
||||
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
|
||||
|
||||
# Step 2 — optional
|
||||
echo ""
|
||||
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
|
||||
echo " will be backed up automatically. To fully revert later, restore"
|
||||
echo " that backup and delete docker-compose.override.yml +"
|
||||
echo " config.yaml.enterprise."
|
||||
local confirm
|
||||
confirm=$(read_yes_no " Continue?" "y")
|
||||
if [[ "$confirm" != "yes" ]]; then
|
||||
MIGRATE_POSTGRES="no"
|
||||
echo " Skipping Postgres migration."
|
||||
else
|
||||
POSTGRES_PASSWORD=$(rand_password)
|
||||
fi
|
||||
fi
|
||||
|
||||
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
|
||||
echo ""
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
# Auth secret MUST match server.authSecret from config.yaml
|
||||
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
|
||||
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
|
||||
NETBIRD_DOMAIN=$(detect_exposed_address)
|
||||
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
|
||||
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
|
||||
fi
|
||||
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
|
||||
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
|
||||
|
||||
# We need the encryption key from the existing config.yaml for the enricher
|
||||
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
|
||||
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
|
||||
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
ENABLE_FLOW="no"
|
||||
echo "Step 3 (traffic flow) skipped — requires Postgres."
|
||||
fi
|
||||
}
|
||||
|
||||
apply_changes() {
|
||||
echo ""
|
||||
echo "Writing $OVERRIDE_FILE ..."
|
||||
install -m 644 /dev/null "$OVERRIDE_FILE"
|
||||
render_override > "$OVERRIDE_FILE"
|
||||
|
||||
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
|
||||
fi
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
|
||||
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
|
||||
render_enterprise_config
|
||||
fi
|
||||
|
||||
# Persist secrets that the override file references via env interpolation.
|
||||
# We write them to a .env file in the current directory; docker compose
|
||||
# picks it up automatically.
|
||||
echo "Writing .env additions (mode 600) ..."
|
||||
local ENV_FILE=".env"
|
||||
touch "$ENV_FILE"
|
||||
chmod 600 "$ENV_FILE"
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
fi
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
|
||||
fi
|
||||
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
|
||||
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
|
||||
fi
|
||||
} >> "$ENV_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Logging in to ghcr.io ..."
|
||||
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||
unset GHCR_TOKEN
|
||||
|
||||
echo ""
|
||||
echo "Pulling enterprise images ..."
|
||||
$DOCKER_COMPOSE_COMMAND pull
|
||||
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
echo ""
|
||||
echo "Stopping existing services (volumes preserved) ..."
|
||||
$DOCKER_COMPOSE_COMMAND down
|
||||
|
||||
backup_sqlite
|
||||
|
||||
echo ""
|
||||
echo "Starting Postgres ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||
|
||||
# Wait for healthy
|
||||
local counter=0
|
||||
echo -n "Waiting for Postgres to become ready"
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
|
||||
echo -n " ."
|
||||
sleep 2
|
||||
counter=$((counter + 1))
|
||||
if [[ $counter -ge 60 ]]; then
|
||||
echo ""
|
||||
echo "Postgres did not become ready in 120s. Recent logs:"
|
||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo " done"
|
||||
|
||||
run_migrate_store
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Bringing up all services ..."
|
||||
$DOCKER_COMPOSE_COMMAND up -d
|
||||
|
||||
echo ""
|
||||
echo "Migration complete."
|
||||
}
|
||||
|
||||
print_summary() {
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Summary"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " Images: swapped to enterprise"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
|
||||
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
|
||||
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
|
||||
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
|
||||
echo ""
|
||||
echo " Generated files (next to your docker-compose.yml):"
|
||||
echo " $OVERRIDE_FILE"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||
echo " .env (license key + secrets, mode 600)"
|
||||
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
|
||||
echo ""
|
||||
echo " Tail logs:"
|
||||
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
|
||||
echo ""
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " To revert"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
echo " $DOCKER_COMPOSE_COMMAND down"
|
||||
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||
# Resolve project-prefixed volume names now (before override is removed).
|
||||
local pg_volume data_volume_actual
|
||||
pg_volume=$(resolve_data_volume "netbird_postgres")
|
||||
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
|
||||
echo " docker volume rm $pg_volume"
|
||||
echo " # Restore SQLite from the backup created during this run:"
|
||||
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
|
||||
fi
|
||||
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
|
||||
echo " $DOCKER_COMPOSE_COMMAND up -d"
|
||||
echo "──────────────────────────────────────────────────────────────────────"
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
init_migration
|
||||
apply_changes
|
||||
print_summary
|
||||
@@ -1,89 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var adminDatadir string
|
||||
|
||||
// newAdminCommands creates the admin command tree with management-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources initializes logging, loads config, opens the management store
|
||||
// and embedded IdP storage, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if adminDatadir != "" {
|
||||
oldDatadir := datadir
|
||||
datadir = adminDatadir
|
||||
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
|
||||
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
|
||||
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
|
||||
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, config)
|
||||
}
|
||||
@@ -1,441 +0,0 @@
|
||||
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own opener to handle config loading and storage initialization.
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
localConnectorID = "local"
|
||||
dashboardClientID = "netbird-dashboard"
|
||||
cliClientID = "netbird-cli"
|
||||
defaultTOTPAuthenticatorID = "default-totp"
|
||||
)
|
||||
|
||||
// Resources contains the storages required by the admin commands.
|
||||
type Resources struct {
|
||||
Store store.Store
|
||||
IDPStorage storage.Storage
|
||||
}
|
||||
|
||||
// Opener initializes command resources from the command context and calls fn.
|
||||
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
|
||||
|
||||
type userSelector struct {
|
||||
email string
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s userSelector) normalized() userSelector {
|
||||
return userSelector{
|
||||
email: strings.TrimSpace(s.email),
|
||||
userID: strings.TrimSpace(s.userID),
|
||||
}
|
||||
}
|
||||
|
||||
func (s userSelector) validate() error {
|
||||
s = s.normalized()
|
||||
if (s.email == "") == (s.userID == "") {
|
||||
return fmt.Errorf("provide exactly one of --email or --user-id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCommands creates the admin command tree with the given resource opener.
|
||||
func NewCommands(opener Opener) *cobra.Command {
|
||||
adminCmd := &cobra.Command{
|
||||
Use: "admin",
|
||||
Short: "Self-hosted administrator helpers",
|
||||
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
|
||||
}
|
||||
|
||||
userCmd := &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "Manage local embedded IdP users",
|
||||
}
|
||||
|
||||
var passwordSelector userSelector
|
||||
var password string
|
||||
var passwordFile string
|
||||
passwordCmd := &cobra.Command{
|
||||
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
|
||||
Aliases: []string{"set-password"},
|
||||
Short: "Change a local user's password",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(passwordCmd, &passwordSelector)
|
||||
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
|
||||
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
|
||||
|
||||
var resetSelector userSelector
|
||||
resetMFACmd := &cobra.Command{
|
||||
Use: "reset-mfa (--email email | --user-id id)",
|
||||
Short: "Reset a local user's MFA enrollment",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(resetMFACmd, &resetSelector)
|
||||
|
||||
userCmd.AddCommand(passwordCmd, resetMFACmd)
|
||||
|
||||
mfaCmd := &cobra.Command{
|
||||
Use: "mfa",
|
||||
Short: "Manage local MFA for embedded IdP users",
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable",
|
||||
Short: "Enable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable",
|
||||
Short: "Disable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show local MFA status",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
|
||||
adminCmd.AddCommand(userCmd, mfaCmd)
|
||||
return adminCmd
|
||||
}
|
||||
|
||||
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
|
||||
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
|
||||
}
|
||||
|
||||
yamlConfig, err := cfg.ToYAMLConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build embedded IdP config: %w", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
st, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
|
||||
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
|
||||
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
|
||||
}
|
||||
|
||||
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
|
||||
if password != "" && passwordFile != "" {
|
||||
return "", fmt.Errorf("provide only one of --password or --password-file")
|
||||
}
|
||||
if passwordFile == "" {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
if passwordFile == "-" {
|
||||
data, err = io.ReadAll(cmd.InOrStdin())
|
||||
} else {
|
||||
data, err = os.ReadFile(passwordFile)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
return strings.TrimRight(string(data), "\r\n"), nil
|
||||
}
|
||||
|
||||
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
if err := server.ValidatePassword(password); err != nil {
|
||||
return fmt.Errorf("invalid password: %w", err)
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||
old.Hash = hash
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update password for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reset := false
|
||||
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
|
||||
old.MFASecrets = map[string]*storage.MFASecret{}
|
||||
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
|
||||
return old, nil
|
||||
})
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reset {
|
||||
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
if len(accounts) != 1 {
|
||||
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
|
||||
}
|
||||
|
||||
settings := &types.Settings{}
|
||||
if accounts[0].Settings != nil {
|
||||
settings = accounts[0].Settings.Copy()
|
||||
}
|
||||
settings.LocalMfaEnabled = enabled
|
||||
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
|
||||
return fmt.Errorf("save local MFA account setting: %w", err)
|
||||
}
|
||||
|
||||
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := "disabled"
|
||||
if enabled {
|
||||
state = "enabled"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
accountStatus := "unknown"
|
||||
if len(accounts) == 1 && accounts[0].Settings != nil {
|
||||
accountStatus = "disabled"
|
||||
if accounts[0].Settings.LocalMfaEnabled {
|
||||
accountStatus = "enabled"
|
||||
}
|
||||
}
|
||||
|
||||
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
|
||||
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return storage.Password{}, err
|
||||
}
|
||||
|
||||
if selector.email != "" {
|
||||
user, err := idpStorage.GetPassword(ctx, selector.email)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
rawUserID := selector.userID
|
||||
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
|
||||
rawUserID = decodedUserID
|
||||
}
|
||||
|
||||
users, err := idpStorage.ListPasswords(ctx)
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("list local users: %w", err)
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.UserID == rawUserID || user.UserID == selector.userID {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
|
||||
}
|
||||
|
||||
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
|
||||
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
|
||||
if err == nil || errors.Is(err, storage.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{defaultTOTPAuthenticatorID}
|
||||
}
|
||||
|
||||
for _, clientID := range []string{cliClientID, dashboardClientID} {
|
||||
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = mfaChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
|
||||
}
|
||||
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
|
||||
clientIDs := []string{cliClientID, dashboardClientID}
|
||||
enabledCount := 0
|
||||
for _, clientID := range clientIDs {
|
||||
client, err := idpStorage.GetClient(ctx, clientID)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
|
||||
}
|
||||
if err != nil {
|
||||
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
|
||||
enabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
switch enabledCount {
|
||||
case 0:
|
||||
return "disabled", nil
|
||||
case len(clientIDs):
|
||||
return "enabled", nil
|
||||
default:
|
||||
return "partially enabled", nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasAuthenticator(chain []string, authenticatorID string) bool {
|
||||
for _, id := range chain {
|
||||
if id == authenticatorID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
)
|
||||
|
||||
func newTestIDPStorage(t *testing.T) storage.Storage {
|
||||
t.Helper()
|
||||
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
|
||||
Email: "user@example.com",
|
||||
Username: "User",
|
||||
UserID: "user-1",
|
||||
Hash: hash,
|
||||
}))
|
||||
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
MFASecrets: map[string]*storage.MFASecret{
|
||||
defaultTOTPAuthenticatorID: {
|
||||
AuthenticatorID: defaultTOTPAuthenticatorID,
|
||||
Type: "TOTP",
|
||||
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
|
||||
Confirmed: true,
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
|
||||
"webauthn": {{CredentialID: []byte("credential")}},
|
||||
},
|
||||
}))
|
||||
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
Nonce: "nonce",
|
||||
}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
|
||||
|
||||
return st
|
||||
}
|
||||
|
||||
func TestRunChangePassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "Password updated")
|
||||
|
||||
user, err := st.GetPassword(ctx, "user@example.com")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunChangePasswordValidatesPassword(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid password")
|
||||
}
|
||||
|
||||
func TestRunResetMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
|
||||
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "MFA reset")
|
||||
|
||||
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, identity.MFASecrets)
|
||||
require.Empty(t, identity.WebAuthnCredentials)
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
old.MFASecrets = nil
|
||||
old.WebAuthnCredentials = nil
|
||||
return old, nil
|
||||
}))
|
||||
|
||||
var out bytes.Buffer
|
||||
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "No MFA enrollment found")
|
||||
}
|
||||
|
||||
func TestSetIDPClientsMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, true))
|
||||
status, err := idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "enabled", status)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, false))
|
||||
status, err = idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "disabled", status)
|
||||
}
|
||||
|
||||
func TestUserSelectorValidate(t *testing.T) {
|
||||
require.NoError(t, userSelector{email: " user@example.com "}.validate())
|
||||
require.NoError(t, userSelector{userID: "user-1"}.validate())
|
||||
require.Error(t, userSelector{}.validate())
|
||||
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
|
||||
}
|
||||
|
||||
func TestFindLocalUserNotFound(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "not found"))
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputFromStdin(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(strings.NewReader("NewPass1!\n"))
|
||||
|
||||
password, err := resolvePasswordInput(cmd, "", "-")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "NewPass1!", password)
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
|
||||
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
ac := newAdminCommands()
|
||||
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(ac)
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
}
|
||||
|
||||
55
management/cmd/token.go
Normal file
55
management/cmd/token.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -33,8 +33,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -84,9 +82,6 @@ type ProxyServiceServer struct {
|
||||
// Manager for users
|
||||
usersManager users.Manager
|
||||
|
||||
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||
idpManager idp.Manager
|
||||
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
@@ -162,7 +157,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -171,7 +166,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
idpManager: idpManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
@@ -1708,7 +1702,22 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
@@ -1745,45 +1754,6 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||
// user or peer ID, and peer name or user email.
|
||||
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||
// principal so multiple peers owned by the same user share a single
|
||||
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||
// user.Email when linked, peer.Name when not.
|
||||
|
||||
// If the peer isn't associated with a user, return the peer info directly.
|
||||
if peer.UserID == "" {
|
||||
return peer.ID, peer.Name
|
||||
}
|
||||
|
||||
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||
// if an IdP is available, we gather details on the user from it.
|
||||
principalID := peer.UserID
|
||||
displayIdentity := peer.Name
|
||||
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
// IdP enrichment wins when available — the stored email column is a
|
||||
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||
if s.idpManager != nil {
|
||||
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||
displayIdentity = ud.Email
|
||||
} else if uerr != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||
}
|
||||
}
|
||||
|
||||
return principalID, displayIdentity
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
|
||||
@@ -3,19 +3,14 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
@@ -142,52 +137,6 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||
// panics if any unexpected method is invoked).
|
||||
type mockTunnelPeersManager struct {
|
||||
peers.Manager
|
||||
peer *peer.Peer
|
||||
peerErr error
|
||||
groups []*types.Group
|
||||
groupsErr error
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||
return m.peer, m.peerErr
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||
return m.peer, m.groups, m.groupsErr
|
||||
}
|
||||
|
||||
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||
// an IdP that knows nothing about the user.
|
||||
type mockTunnelIdpManager struct {
|
||||
idp.Manager
|
||||
email string
|
||||
hasData bool
|
||||
err error
|
||||
gotCalls int
|
||||
gotMeta []idp.AppMetadata
|
||||
}
|
||||
|
||||
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||
m.gotCalls++
|
||||
m.gotMeta = append(m.gotMeta, meta)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
if !m.hasData {
|
||||
// This might not be a thing any of the actual IDP implementations do,
|
||||
// i.e. return a nil value with no error, but it seems valuable to test
|
||||
// that behavior here.
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -405,163 +354,6 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||
// (IdP email -> stored User.Email -> peer.Name).
|
||||
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||
const (
|
||||
domain = "app.example.com"
|
||||
accountID = "account1"
|
||||
peerID = "peer1"
|
||||
peerName = "peer-display-name"
|
||||
userID = "user1"
|
||||
)
|
||||
|
||||
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
peerUserID string
|
||||
storedUsers map[string]*types.User
|
||||
storedErr error
|
||||
noIdP bool
|
||||
idpEmail string
|
||||
idpHasData bool
|
||||
idpErr error
|
||||
expectEmail string
|
||||
expectUserID string
|
||||
expectIdPHit bool
|
||||
}{
|
||||
{
|
||||
name: "idp email wins over stored email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp returns empty email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "",
|
||||
idpHasData: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp has no data",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpHasData: false,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp errors",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpErr: errors.New("idp unreachable"),
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when no idp manager",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
noIdP: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored email is empty",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||
peerUserID: userID,
|
||||
storedUsers: map[string]*types.User{},
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "unlinked peer uses peer name and never consults idp",
|
||||
peerUserID: "",
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: peerID,
|
||||
expectIdPHit: false,
|
||||
},
|
||||
{
|
||||
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
noIdP: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: userID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||
},
|
||||
peersManager: &mockTunnelPeersManager{
|
||||
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||
},
|
||||
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||
}
|
||||
|
||||
var idpMock *mockTunnelIdpManager
|
||||
if !tt.noIdP {
|
||||
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||
server.idpManager = idpMock
|
||||
}
|
||||
|
||||
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||
Domain: domain,
|
||||
TunnelIp: "100.64.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.GetValid(), "expected access granted")
|
||||
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||
|
||||
if idpMock != nil {
|
||||
if tt.expectIdPHit {
|
||||
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||
require.Len(t, idpMock.gotMeta, 1)
|
||||
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||
} else {
|
||||
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -217,7 +217,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -982,6 +982,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
var peer *nbpeer.Peer
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var err error
|
||||
var postureChecks []*posture.Checks
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -1009,8 +1011,13 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
@@ -1018,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -1025,11 +1037,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
@@ -1040,9 +1047,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1117,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var shouldStorePeer, shouldUpdatePeers bool
|
||||
var shouldStorePeer bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1144,10 +1151,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
shouldUpdatePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
@@ -1169,15 +1180,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
@@ -1187,7 +1190,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if shouldUpdatePeers {
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1283,22 +1286,12 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
@@ -1306,16 +1299,32 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
|
||||
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
|
||||
for _, peerID := range peerGroupIDs {
|
||||
groupIDsMap[peerID] = struct{}{}
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1327,7 +1336,11 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
|
||||
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||
}
|
||||
|
||||
@@ -1340,19 +1353,29 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
}
|
||||
|
||||
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
||||
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
|
||||
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
if slices.Contains(peerGroupIDs, sourceGroup) {
|
||||
return policy.SourcePostureChecks
|
||||
group, ok := sourceGroups[sourceGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to check peer in policy source group")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return policy.SourcePostureChecks, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
@@ -166,7 +162,49 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
return len(metaDiff(p, other)) == 0
|
||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
||||
})
|
||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
||||
})
|
||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
||||
})
|
||||
if !equalNetworkAddresses {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(p.Files, func(i, j int) bool {
|
||||
return p.Files[i].Path < p.Files[j].Path
|
||||
})
|
||||
sort.Slice(other.Files, func(i, j int) bool {
|
||||
return other.Files[i].Path < other.Files[j].Path
|
||||
})
|
||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
||||
})
|
||||
if !equalFiles {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Hostname == other.Hostname &&
|
||||
p.GoOS == other.GoOS &&
|
||||
p.Kernel == other.Kernel &&
|
||||
p.KernelVersion == other.KernelVersion &&
|
||||
p.Core == other.Core &&
|
||||
p.Platform == other.Platform &&
|
||||
p.OS == other.OS &&
|
||||
p.OSVersion == other.OSVersion &&
|
||||
p.WtVersion == other.WtVersion &&
|
||||
p.UIVersion == other.UIVersion &&
|
||||
p.SystemSerialNumber == other.SystemSerialNumber &&
|
||||
p.SystemProductName == other.SystemProductName &&
|
||||
p.SystemManufacturer == other.SystemManufacturer &&
|
||||
p.Environment.Cloud == other.Environment.Cloud &&
|
||||
p.Environment.Platform == other.Environment.Platform &&
|
||||
p.Flags.isEqual(other.Flags) &&
|
||||
capabilitiesEqual(p.Capabilities, other.Capabilities)
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEmpty() bool {
|
||||
@@ -258,7 +296,7 @@ func (p *Peer) Copy() *Peer {
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
if meta.isEmpty() {
|
||||
return updated, versionChanged
|
||||
}
|
||||
@@ -270,121 +308,14 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (update
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := metaDiff(p.Meta, meta)
|
||||
if len(diff) != 0 {
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
if p.Meta.isEqual(meta) {
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if versionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if len(diff) > 0 || versionChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||
}
|
||||
|
||||
p.Meta = meta
|
||||
updated = true
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
// metaDiff returns a human-readable list of the fields that differ between the
|
||||
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||
// diff, so the log line can never disagree with the change decision. Slices are
|
||||
// cloned before sorting, so callers' meta is not mutated.
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
var diff []string
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
// same multiplicity, ignoring order. The element type is the comparison key, so
|
||||
// every field participates in equality.
|
||||
func sameMultiset[T comparable](a, b []T) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
counts := make(map[T]int, len(a))
|
||||
for _, v := range a {
|
||||
counts[v]++
|
||||
}
|
||||
for _, v := range b {
|
||||
counts[v]--
|
||||
if counts[v] == 0 {
|
||||
delete(counts, v)
|
||||
}
|
||||
}
|
||||
return len(counts) == 0
|
||||
}
|
||||
|
||||
// GetLastLogin returns the last login time of the peer.
|
||||
func (p *Peer) GetLastLogin() time.Time {
|
||||
if p.LastLogin != nil {
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
|
||||
// map 1:1 to a single diff entry. Today the only such field is Environment, which
|
||||
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
|
||||
// entry beyond its single struct field. If you teach metaDiff to explode another
|
||||
// field into N entries, bump this by N-1; if you collapse a field, lower it.
|
||||
const metaDiffExtraEntries = 1
|
||||
|
||||
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
|
||||
// values and diffs it against the zero value, then asserts metaDiff emits exactly
|
||||
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
|
||||
//
|
||||
// The expected count is derived from the struct via reflection, so adding a field
|
||||
// to PeerSystemMeta raises the expectation automatically — but the actual diff
|
||||
// only grows if metaDiff was taught to compare the new field. A mismatch means
|
||||
// someone changed the struct without updating metaDiff (or this test's
|
||||
// extra-entry accounting), which is exactly what we want to catch.
|
||||
func TestMetaDiff_CoversAllFields(t *testing.T) {
|
||||
var full PeerSystemMeta
|
||||
exported := populateAll(t, reflect.ValueOf(&full).Elem())
|
||||
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
|
||||
|
||||
diff := metaDiff(PeerSystemMeta{}, full)
|
||||
|
||||
require.Len(t, diff, exported+metaDiffExtraEntries,
|
||||
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
|
||||
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
|
||||
"diff was: %v", diff)
|
||||
|
||||
require.False(t, full.isEqual(PeerSystemMeta{}),
|
||||
"isEqual must report a fully-populated meta as different from the zero value")
|
||||
}
|
||||
|
||||
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
|
||||
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
|
||||
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
|
||||
// compare would not change the diff count. This flips each Flags field on its own
|
||||
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
|
||||
// fails here.
|
||||
func TestFlags_isEqualChecksEveryField(t *testing.T) {
|
||||
typ := reflect.TypeOf(Flags{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
require.Equal(t, reflect.Bool, f.Type.Kind(),
|
||||
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
|
||||
|
||||
var a, b Flags
|
||||
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
|
||||
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// populateAll sets every exported field of the struct to a deterministic non-zero
|
||||
// value, recursing into nested structs and the element type of struct slices so
|
||||
// that each leaf differs from zero. It returns the number of exported fields on
|
||||
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
|
||||
// settable exported fields and is comparable with ==).
|
||||
func populateAll(t *testing.T, v reflect.Value) int {
|
||||
t.Helper()
|
||||
|
||||
typ := v.Type()
|
||||
exported := 0
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
if f.PkgPath != "" { // unexported
|
||||
continue
|
||||
}
|
||||
exported++
|
||||
setNonZero(t, v.Field(i))
|
||||
}
|
||||
return exported
|
||||
}
|
||||
|
||||
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
|
||||
// recursing into nested structs and populating one element of slice fields.
|
||||
func setNonZero(t *testing.T, field reflect.Value) {
|
||||
t.Helper()
|
||||
|
||||
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
|
||||
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
|
||||
return
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
field.SetString("non-zero")
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(7)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.SetUint(7)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.SetFloat(7)
|
||||
case reflect.Struct:
|
||||
populateAll(t, field)
|
||||
case reflect.Slice:
|
||||
s := reflect.MakeSlice(field.Type(), 1, 1)
|
||||
setNonZero(t, s.Index(0))
|
||||
field.Set(s)
|
||||
default:
|
||||
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
|
||||
}
|
||||
}
|
||||
@@ -1847,17 +1847,12 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements.
|
||||
func validatePassword(password string) error {
|
||||
return ValidatePassword(password)
|
||||
}
|
||||
|
||||
// ValidatePassword checks password strength requirements:
|
||||
// validatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func ValidatePassword(password string) error {
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
@@ -125,7 +125,6 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -140,7 +140,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -21,8 +21,7 @@ AWK_FIRST_FIELD='{print $1}'
|
||||
|
||||
fetch_all_tags() {
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||
grep -iv 'rc' | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
|
||||
@@ -32,8 +32,7 @@ fetch_current_ports_version() {
|
||||
fetch_all_tags() {
|
||||
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||
grep -iv 'rc' | \
|
||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
||||
sed 's/.*\/v//' | \
|
||||
sort -u -V
|
||||
return 0
|
||||
|
||||
@@ -33,7 +33,7 @@ type Client interface {
|
||||
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||
Ready() bool
|
||||
IsHealthy() bool
|
||||
WaitStreamConnected(context.Context)
|
||||
WaitStreamConnected()
|
||||
SendToStream(msg *proto.EncryptedMessage) error
|
||||
Send(msg *proto.Message) error
|
||||
SetOnReconnectedListener(func())
|
||||
|
||||
@@ -65,10 +65,7 @@ var _ = Describe("GrpcClient", func() {
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctxA, cancelA := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelA()
|
||||
clientA.WaitStreamConnected(ctxA)
|
||||
Expect(clientA.StreamConnected()).To(BeTrue())
|
||||
clientA.WaitStreamConnected()
|
||||
|
||||
// connect PeerB to Signal
|
||||
keyB, _ := wgtypes.GenerateKey()
|
||||
@@ -94,10 +91,7 @@ var _ = Describe("GrpcClient", func() {
|
||||
}
|
||||
}()
|
||||
|
||||
ctxB, cancelB := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelB()
|
||||
clientB.WaitStreamConnected(ctxB)
|
||||
Expect(clientB.StreamConnected()).To(BeTrue())
|
||||
clientB.WaitStreamConnected()
|
||||
|
||||
// PeerA initiates ping-pong
|
||||
err := clientA.Send(&sigProto.Message{
|
||||
@@ -135,10 +129,8 @@ var _ = Describe("GrpcClient", func() {
|
||||
return
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
client.WaitStreamConnected(ctx)
|
||||
Expect(client.StreamConnected()).To(BeTrue())
|
||||
client.WaitStreamConnected()
|
||||
Expect(client).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -246,6 +246,15 @@ func (c *GrpcClient) notifyStreamConnected() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
if c.connectedCh == nil {
|
||||
c.connectedCh = make(chan struct{})
|
||||
}
|
||||
return c.connectedCh
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
c.stream = nil
|
||||
|
||||
@@ -301,24 +310,14 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
}
|
||||
|
||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||
func (c *GrpcClient) WaitStreamConnected(ctx context.Context) {
|
||||
// Check the status and obtain the wait channel atomically: otherwise
|
||||
// notifyStreamConnected could flip the status and close/clear the channel
|
||||
// between the check and the channel creation, leaving us waiting forever on
|
||||
// a stale channel.
|
||||
c.mux.Lock()
|
||||
func (c *GrpcClient) WaitStreamConnected() {
|
||||
|
||||
if c.status == StreamConnected {
|
||||
c.mux.Unlock()
|
||||
return
|
||||
}
|
||||
if c.connectedCh == nil {
|
||||
c.connectedCh = make(chan struct{})
|
||||
}
|
||||
ch := c.connectedCh
|
||||
c.mux.Unlock()
|
||||
|
||||
ch := c.getStreamStatusChan()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.ctx.Done():
|
||||
case <-ch:
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (sm *MockClient) Ready() bool {
|
||||
return sm.ReadyFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) WaitStreamConnected(context.Context) {
|
||||
func (sm *MockClient) WaitStreamConnected() {
|
||||
if sm.WaitStreamConnectedFunc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
|
||||
|
||||
streamReady := make(chan struct{})
|
||||
go func() {
|
||||
client.WaitStreamConnected(ctx)
|
||||
client.WaitStreamConnected()
|
||||
close(streamReady)
|
||||
}()
|
||||
select {
|
||||
|
||||
@@ -140,7 +140,12 @@ func newRotatedOutput(logPath string) io.Writer {
|
||||
func setGRPCLibLogger(logger *log.Logger) {
|
||||
logOut := logger.Writer()
|
||||
if os.Getenv("GRPC_GO_LOG_SEVERITY_LEVEL") != "info" {
|
||||
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, logOut, logOut))
|
||||
// Discard grpc info AND warning logs by default — the warning stream is
|
||||
// dominated by benign connection-retry noise ("addrConn.createTransport
|
||||
// failed", "transport is closing") that surfaces e.g. when the CLI dials
|
||||
// a daemon that is still starting or already gone. Errors are kept. Set
|
||||
// GRPC_GO_LOG_SEVERITY_LEVEL=info to get the full verbose grpc logging.
|
||||
grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, io.Discard, logOut))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user