mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-22 15:59:59 +00:00
Compare commits
21 Commits
fix/ipv6-a
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
520370a8b0 | ||
|
|
af3b7e4497 | ||
|
|
e84f6527f7 | ||
|
|
ac9529ea8c | ||
|
|
f736ef9647 | ||
|
|
cf58bf1ba9 | ||
|
|
522b8ed969 | ||
|
|
c9e99659ea | ||
|
|
58c79f5878 | ||
|
|
15a0504fb1 | ||
|
|
883a1a8961 | ||
|
|
54192a94b7 | ||
|
|
8511687270 | ||
|
|
35b465fa4a | ||
|
|
fb87f751a5 | ||
|
|
679c7182a4 | ||
|
|
8c031ea6f0 | ||
|
|
60a9544656 | ||
|
|
d3710d4bb2 | ||
|
|
b5a16a1898 | ||
|
|
449b5cbb80 |
@@ -20,7 +20,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -59,12 +59,12 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: true
|
cache: true
|
||||||
|
|||||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||||
|
|||||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
|
|||||||
4
.github/workflows/golang-test-freebsd.yml
vendored
4
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ jobs:
|
|||||||
id: test
|
id: test
|
||||||
env:
|
env:
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
|
|||||||
52
.github/workflows/golang-test-linux.yml
vendored
52
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
management: ${{ steps.filter.outputs.management }}
|
management: ${{ steps.filter.outputs.management }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- 'management/**'
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -119,12 +119,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -162,7 +162,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -175,12 +175,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -246,12 +246,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -290,7 +290,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -306,12 +306,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -347,7 +347,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -363,12 +363,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -407,7 +407,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -424,12 +424,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -484,7 +484,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -529,12 +529,12 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -623,12 +623,12 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -692,12 +692,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -734,7 +734,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
|
|||||||
4
.github/workflows/golang-test-windows.yml
vendored
4
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
|||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|||||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: codespell
|
- name: codespell
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
timeout-minutes: 15
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Check for duplicate constants
|
- name: Check for duplicate constants
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|||||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
|||||||
10
.github/workflows/mobile-build-validation.yml
vendored
10
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
@@ -28,7 +28,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
cmdline-tools-version: 8512546
|
cmdline-tools-version: 8512546
|
||||||
- name: Setup Java
|
- name: Setup Java
|
||||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
||||||
with:
|
with:
|
||||||
java-version: "11"
|
java-version: "11"
|
||||||
distribution: "adopt"
|
distribution: "adopt"
|
||||||
@@ -54,11 +54,11 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
|
|||||||
32
.github/workflows/release.yml
vendored
32
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ jobs:
|
|||||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
env:
|
env:
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
@@ -135,7 +135,7 @@ jobs:
|
|||||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -166,7 +166,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -186,9 +186,9 @@ jobs:
|
|||||||
- name: check git status
|
- name: check git status
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
@@ -221,7 +221,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --clean ${{ env.flags }}
|
args: release --clean ${{ env.flags }}
|
||||||
@@ -347,7 +347,7 @@ jobs:
|
|||||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -374,7 +374,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -420,7 +420,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||||
@@ -464,12 +464,12 @@ jobs:
|
|||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -488,7 +488,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||||
@@ -522,7 +522,7 @@ jobs:
|
|||||||
downloadPath: '${{ github.workspace }}\temp'
|
downloadPath: '${{ github.workspace }}\temp'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -534,13 +534,13 @@ jobs:
|
|||||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
|
|
||||||
- name: Download release artifacts
|
- name: Download release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: release
|
path: release
|
||||||
|
|
||||||
- name: Download UI release artifacts
|
- name: Download UI release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: release-ui
|
path: release-ui
|
||||||
|
|||||||
12
.github/workflows/test-infrastructure-files.yml
vendored
12
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,12 +68,12 @@ jobs:
|
|||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ jobs:
|
|||||||
- name: Build management docker image
|
- name: Build management docker image
|
||||||
working-directory: management
|
working-directory: management
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/management:latest .
|
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: Build signal binary
|
- name: Build signal binary
|
||||||
working-directory: signal
|
working-directory: signal
|
||||||
@@ -216,7 +216,7 @@ jobs:
|
|||||||
- name: Build signal docker image
|
- name: Build signal docker image
|
||||||
working-directory: signal
|
working-directory: signal
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/signal:latest .
|
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: Build relay binary
|
- name: Build relay binary
|
||||||
working-directory: relay
|
working-directory: relay
|
||||||
@@ -225,7 +225,7 @@ jobs:
|
|||||||
- name: Build relay docker image
|
- name: Build relay docker image
|
||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/relay:latest .
|
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files/artifacts
|
working-directory: infrastructure_files/artifacts
|
||||||
@@ -256,7 +256,7 @@ jobs:
|
|||||||
run: sudo apt-get install -y jq
|
run: sudo apt-get install -y jq
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
|||||||
GOARCH: wasm
|
GOARCH: wasm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -44,11 +44,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ dockers_v2:
|
|||||||
- netbirdio/netbird
|
- netbirdio/netbird
|
||||||
- ghcr.io/netbirdio/netbird
|
- ghcr.io/netbirdio/netbird
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: client/Dockerfile
|
dockerfile: client/Dockerfile
|
||||||
extra_files:
|
extra_files:
|
||||||
@@ -295,7 +295,7 @@ dockers_v2:
|
|||||||
- netbirdio/relay
|
- netbirdio/relay
|
||||||
- ghcr.io/netbirdio/relay
|
- ghcr.io/netbirdio/relay
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: relay/Dockerfile
|
dockerfile: relay/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -317,7 +317,7 @@ dockers_v2:
|
|||||||
- netbirdio/signal
|
- netbirdio/signal
|
||||||
- ghcr.io/netbirdio/signal
|
- ghcr.io/netbirdio/signal
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: signal/Dockerfile
|
dockerfile: signal/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -339,7 +339,7 @@ dockers_v2:
|
|||||||
- netbirdio/management
|
- netbirdio/management
|
||||||
- ghcr.io/netbirdio/management
|
- ghcr.io/netbirdio/management
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: management/Dockerfile
|
dockerfile: management/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -361,7 +361,7 @@ dockers_v2:
|
|||||||
- netbirdio/upload
|
- netbirdio/upload
|
||||||
- ghcr.io/netbirdio/upload
|
- ghcr.io/netbirdio/upload
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: upload-server/Dockerfile
|
dockerfile: upload-server/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -383,7 +383,7 @@ dockers_v2:
|
|||||||
- netbirdio/netbird-server
|
- netbirdio/netbird-server
|
||||||
- ghcr.io/netbirdio/netbird-server
|
- ghcr.io/netbirdio/netbird-server
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: combined/Dockerfile
|
dockerfile: combined/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -405,7 +405,7 @@ dockers_v2:
|
|||||||
- netbirdio/reverse-proxy
|
- netbirdio/reverse-proxy
|
||||||
- ghcr.io/netbirdio/reverse-proxy
|
- ghcr.io/netbirdio/reverse-proxy
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: proxy/Dockerfile
|
dockerfile: proxy/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -462,9 +462,13 @@ checksum:
|
|||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
- glob: ./release_files/install.sh
|
- glob: ./release_files/install.sh
|
||||||
- glob: ./infrastructure_files/getting-started.sh
|
- glob: ./infrastructure_files/getting-started.sh
|
||||||
|
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||||
|
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||||
|
|
||||||
release:
|
release:
|
||||||
extra_files:
|
extra_files:
|
||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
- glob: ./release_files/install.sh
|
- glob: ./release_files/install.sh
|
||||||
- glob: ./infrastructure_files/getting-started.sh
|
- glob: ./infrastructure_files/getting-started.sh
|
||||||
|
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||||
|
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
|||||||
Username: &username,
|
Username: &username,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.ID(resp.Id), nil
|
return profilemanager.ID(resp.Id), nil
|
||||||
|
|||||||
@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
currUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get current user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
currUser, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get current user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||||
profileName := args[0]
|
profileName := args[0]
|
||||||
|
|
||||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||||
ProfileName: profileName,
|
|
||||||
Username: currUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add profile request: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||||
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||||
}
|
}
|
||||||
|
|
||||||
id := profilemanager.ID(resp.Id)
|
|
||||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||||
|
// and returns the new profile's ID. It is the single entry point for profile
|
||||||
|
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||||
|
// <name>` auto-create path.
|
||||||
|
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||||
|
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||||
|
ProfileName: profileName,
|
||||||
|
Username: username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("add profile failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return profilemanager.ID(resp.Id), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pm := profilemanager.NewProfileManager()
|
// Resolve the active profile's display name via the daemon, which runs
|
||||||
var profName string
|
// as root and can read the per-user profile files. The local profile
|
||||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
// manager only knows the active profile ID, not its display name.
|
||||||
profName = activeProf.Name
|
profName := getActiveProfileName(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getActiveProfileName asks the daemon for the active profile's display
|
||||||
|
// name. The daemon runs as root and can read the per-user profile files to
|
||||||
|
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||||
|
// error so status output degrades gracefully.
|
||||||
|
func getActiveProfileName(ctx context.Context) string {
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.GetProfileName()
|
||||||
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "idle", "connecting", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
|
|||||||
@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
var profileSwitched bool
|
var profileSwitched bool
|
||||||
// switch profile if provided
|
// switch profile if provided
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("switch profile: %v", err)
|
return fmt.Errorf("switch profile: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
|
||||||
return fmt.Errorf("switch profile: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
profileSwitched = true
|
profileSwitched = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// switchOrCreateProfile switches the active profile to the one identified by
|
||||||
|
// handle, creating it first when it does not exist yet. This restores the
|
||||||
|
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||||
|
// missing profile instead of failing.
|
||||||
|
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||||
|
resolvedID, err := switchProfile(ctx, handle, username)
|
||||||
|
if err != nil {
|
||||||
|
st, ok := gstatus.FromError(err)
|
||||||
|
if !ok || st.Code() != codes.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Don't fail immediately on a create error: a concurrent run may
|
||||||
|
// have created the profile between the NotFound above and this
|
||||||
|
// call, in which case the retried switch still succeeds. Only
|
||||||
|
// surface the create error if the switch also fails.
|
||||||
|
_, createErr := createProfile(ctx, handle, username)
|
||||||
|
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||||
|
if createErr != nil {
|
||||||
|
return fmt.Errorf("create profile: %w", createErr)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createProfile dials the daemon and creates a new profile with the given
|
||||||
|
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||||
|
// when a daemon client is already available to reuse the connection.
|
||||||
|
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
//nolint
|
||||||
|
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||||
|
"If the daemon is not running please run: "+
|
||||||
|
"\nnetbird service install \nnetbird service start\n", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||||
|
}
|
||||||
|
|
||||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||||
// override the default profile filepath if provided
|
// override the default profile filepath if provided
|
||||||
if configPath != "" {
|
if configPath != "" {
|
||||||
|
|||||||
@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-startCtx.Done():
|
case <-startCtx.Done():
|
||||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||||
// signal stream while holding the engine mutex and only unblocks on
|
// run loop to tear the engine down, so this cancel() is no longer
|
||||||
// cancellation. Stopping first would deadlock on that mutex.
|
// required to break the deadlock and could be removed. It is kept as a
|
||||||
|
// defensive belt-and-suspenders: cancelling the parent context first
|
||||||
|
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||||
cancel()
|
cancel()
|
||||||
if stopErr := client.Stop(); stopErr != nil {
|
if stopErr := client.Stop(); stopErr != nil {
|
||||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
|
|||||||
|
|
||||||
type ConnectClient struct {
|
type ConnectClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
runCancel context.CancelFunc
|
||||||
|
runExited chan struct{}
|
||||||
|
runOnce sync.Once
|
||||||
|
runStarted atomic.Bool
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -70,8 +75,14 @@ func NewConnectClient(
|
|||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
|
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||||
|
// loop. runCancel is set once at construction, so Stop can call it without
|
||||||
|
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||||
|
runCtx, runCancel := context.WithCancel(ctx)
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: runCtx,
|
||||||
|
runCancel: runCancel,
|
||||||
|
runExited: make(chan struct{}),
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
engineMutex: sync.Mutex{},
|
engineMutex: sync.Mutex{},
|
||||||
@@ -135,6 +146,11 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||||
|
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||||
|
// the loop to finish (and skip the wait if the loop never ran).
|
||||||
|
c.runStarted.Store(true)
|
||||||
|
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -290,7 +306,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
c.runCancel()
|
||||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -410,14 +426,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
// todo: consider to remove this condition. Is not thread safe.
|
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||||
// We should always call Stop(), but we need to verify that it is idempotent
|
|
||||||
if engine.wgInterface != nil {
|
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
c.statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
@@ -433,12 +445,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.statusRecorder.ClientStart()
|
c.statusRecorder.ClientStart()
|
||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
c.runCancel()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -516,11 +528,9 @@ func (c *ConnectClient) Status() StatusType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) Stop() error {
|
func (c *ConnectClient) Stop() error {
|
||||||
engine := c.Engine()
|
c.runCancel()
|
||||||
if engine != nil {
|
if c.runStarted.Load() {
|
||||||
if err := engine.Stop(); err != nil {
|
<-c.runExited
|
||||||
return fmt.Errorf("stop engine: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
|
|||||||
}
|
}
|
||||||
return "[" + strings.Join(parts, ", ") + "]"
|
return "[" + strings.Join(parts, ", ") + "]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||||
|
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||||
|
// advertise EDNS0.
|
||||||
|
func StripOPT(msg *dns.Msg) {
|
||||||
|
if len(msg.Extra) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := msg.Extra[:0]
|
||||||
|
for _, rr := range msg.Extra {
|
||||||
|
if _, ok := rr.(*dns.OPT); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, rr)
|
||||||
|
}
|
||||||
|
msg.Extra = out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||||
|
// the message, if present.
|
||||||
|
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||||
|
opt := msg.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
for _, o := range opt.Option {
|
||||||
|
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||||
|
return ede, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|||||||
@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripOPT(t *testing.T) {
|
||||||
|
rm := &dns.Msg{
|
||||||
|
Extra: []dns.RR{
|
||||||
|
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||||
|
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
StripOPT(rm)
|
||||||
|
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||||
|
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||||
|
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractEDE(t *testing.T) {
|
||||||
|
t.Run("no edns", func(t *testing.T) {
|
||||||
|
_, ok := ExtractEDE(&dns.Msg{})
|
||||||
|
assert.False(t, ok, "message without OPT has no EDE")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("edns without ede", func(t *testing.T) {
|
||||||
|
rm := &dns.Msg{}
|
||||||
|
rm.SetEdns0(4096, false)
|
||||||
|
_, ok := ExtractEDE(rm)
|
||||||
|
assert.False(t, ok, "OPT without EDE option returns false")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with ede", func(t *testing.T) {
|
||||||
|
rm := &dns.Msg{}
|
||||||
|
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||||
|
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||||
|
rm.Extra = append(rm.Extra, opt)
|
||||||
|
|
||||||
|
ede, ok := ExtractEDE(rm)
|
||||||
|
assert.True(t, ok, "EDE option should be found")
|
||||||
|
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||||
|
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -38,11 +39,15 @@ const (
|
|||||||
// defaultWarningDelayBase is the starting grace window before a
|
// defaultWarningDelayBase is the starting grace window before a
|
||||||
// "Nameserver group unreachable" event fires for a group that's
|
// "Nameserver group unreachable" event fires for a group that's
|
||||||
// never been healthy and only has overlay upstreams with no
|
// never been healthy and only has overlay upstreams with no
|
||||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||||
defaultWarningDelayBase = 30 * time.Second
|
// see warningDelay.
|
||||||
|
defaultWarningDelayBase = 60 * time.Second
|
||||||
// warningDelayBonusCap caps the route-count bonus added to the
|
// warningDelayBonusCap caps the route-count bonus added to the
|
||||||
// base grace window. See warningDelayFor.
|
// base grace window. See warningDelay.
|
||||||
warningDelayBonusCap = 30 * time.Second
|
warningDelayBonusCap = 30 * time.Second
|
||||||
|
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||||
|
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||||
|
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||||
)
|
)
|
||||||
|
|
||||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||||
@@ -135,7 +140,7 @@ type DefaultServer struct {
|
|||||||
disableSys bool
|
disableSys bool
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxHandlers []handlerWrapper
|
||||||
localResolver *local.Resolver
|
localResolver *local.Resolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
@@ -199,8 +204,6 @@ type handlerWrapper struct {
|
|||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
|
||||||
|
|
||||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||||
type DefaultServerConfig struct {
|
type DefaultServerConfig struct {
|
||||||
WgInterface WGIface
|
WgInterface WGIface
|
||||||
@@ -289,7 +292,6 @@ func newDefaultServer(
|
|||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: handlerChain,
|
handlerChain: handlerChain,
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
@@ -298,7 +300,7 @@ func newDefaultServer(
|
|||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
mgmtCacheResolver: mgmtCacheResolver,
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||||
warningDelayBase: defaultWarningDelayBase,
|
warningDelayBase: warningDelayBaseFromEnv(),
|
||||||
healthRefresh: make(chan struct{}, 1),
|
healthRefresh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
// Wire the local resolver against the peer status recorder so it can
|
// Wire the local resolver against the peer status recorder so it can
|
||||||
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
|||||||
type routeSettable interface {
|
type routeSettable interface {
|
||||||
setSelectedRoutes(func() route.HAMap)
|
setSelectedRoutes(func() route.HAMap)
|
||||||
}
|
}
|
||||||
for _, entry := range s.dnsMuxMap {
|
for _, entry := range s.dnsMuxHandlers {
|
||||||
if h, ok := entry.handler.(routeSettable); ok {
|
if h, ok := entry.handler.(routeSettable); ok {
|
||||||
h.setSelectedRoutes(selected)
|
h.setSelectedRoutes(selected)
|
||||||
}
|
}
|
||||||
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
|||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
for _, existing := range s.dnsMuxMap {
|
for _, existing := range s.dnsMuxHandlers {
|
||||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
existing.handler.Stop()
|
// The local resolver is a persistent singleton shared by every custom
|
||||||
|
// zone and reused across config updates. Its chain registrations are
|
||||||
|
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||||
|
// context (breaking external CNAME-target resolution) and clears its
|
||||||
|
// records, so it must not be torn down here.
|
||||||
|
if existing.handler != s.localResolver {
|
||||||
|
existing.handler.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.handler.ID()] = update
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxHandlers = muxUpdates
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||||
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||||
|
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||||
|
// non-positive values fall back to defaultWarningDelayBase.
|
||||||
|
func warningDelayBaseFromEnv() time.Duration {
|
||||||
|
val := os.Getenv(envWarningDelay)
|
||||||
|
if val == "" {
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
d, err := time.ParseDuration(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
// warningDelay returns the grace window for the given selected-route
|
// warningDelay returns the grace window for the given selected-route
|
||||||
// count. Scales gently: +1s per 100 routes, capped by
|
// count. Scales gently: +1s per 100 routes, capped by
|
||||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||||
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
|||||||
// in more than one handler.
|
// in more than one handler.
|
||||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||||
for _, entry := range s.dnsMuxMap {
|
for _, entry := range s.dnsMuxHandlers {
|
||||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -104,19 +104,6 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
|
||||||
var srvs []netip.AddrPort
|
|
||||||
for _, srv := range servers {
|
|
||||||
srvs = append(srvs, srv.AddrPort())
|
|
||||||
}
|
|
||||||
u := &upstreamResolverBase{
|
|
||||||
domain: domain.Domain(d),
|
|
||||||
cancel: func() {},
|
|
||||||
}
|
|
||||||
u.addRace(srvs)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyHandler := local.NewResolver()
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
initUpstreamMap registeredHandlerMap
|
initUpstreamMap []handlerWrapper
|
||||||
initLocalZones []nbdns.CustomZone
|
initLocalZones []nbdns.CustomZone
|
||||||
initSerial uint64
|
initSerial uint64
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
inputUpdate nbdns.Config
|
inputUpdate nbdns.Config
|
||||||
shouldFail bool
|
shouldFail bool
|
||||||
expectedUpstreamMap registeredHandlerMap
|
expectedUpstreamMap []handlerWrapper
|
||||||
expectedLocalQs []dns.Question
|
expectedLocalQs []dns.Question
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial Config Should Succeed",
|
name: "Initial Config Should Succeed",
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: nil,
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
dummyHandler.ID(): handlerWrapper{
|
{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityLocal,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: []handlerWrapper{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: &mockHandler{},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"local-resolver": handlerWrapper{
|
{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityLocal,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: nil,
|
||||||
initSerial: 2,
|
initSerial: 2,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: nil,
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: nil,
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: nil,
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
expectedUpstreamMap: []handlerWrapper{{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: []handlerWrapper{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: &mockHandler{},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: nil,
|
||||||
expectedLocalQs: []dns.Question{},
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: []handlerWrapper{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: &mockHandler{},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: nil,
|
||||||
expectedLocalQs: []dns.Question{},
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||||
}
|
}
|
||||||
|
|
||||||
for key := range testCase.expectedUpstreamMap {
|
for _, expected := range testCase.expectedUpstreamMap {
|
||||||
_, found := dnsServer.dnsMuxMap[key]
|
found := false
|
||||||
|
for _, got := range dnsServer.dnsMuxHandlers {
|
||||||
|
if got.domain == expected.domain && got.priority == expected.priority {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if !found {
|
if !found {
|
||||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||||
"id1": handlerWrapper{
|
{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: &local.Resolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
|||||||
func (m *mockService) DeregisterMux(string) {}
|
func (m *mockService) DeregisterMux(string) {}
|
||||||
|
|
||||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||||
baseMatchHandlers := registeredHandlerMap{
|
baseMatchHandlers := []handlerWrapper{
|
||||||
"upstream-group1": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
baseRootHandlers := registeredHandlerMap{
|
baseRootHandlers := []handlerWrapper{
|
||||||
"upstream-root1": {
|
{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-root1",
|
Id: "upstream-root1",
|
||||||
},
|
},
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
"upstream-root2": {
|
{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-root2",
|
Id: "upstream-root2",
|
||||||
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
baseMixedHandlers := registeredHandlerMap{
|
baseMixedHandlers := []handlerWrapper{
|
||||||
"upstream-group1": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
"upstream-other": {
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialHandlers registeredHandlerMap
|
initialHandlers []handlerWrapper
|
||||||
updates []handlerWrapper
|
updates []handlerWrapper
|
||||||
expectedHandlers map[string]string // map[HandlerID]domain
|
expectedHandlers map[string]string // map[HandlerID]domain
|
||||||
description string
|
description string
|
||||||
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
server := &DefaultServer{
|
server := &DefaultServer{
|
||||||
dnsMuxMap: tt.initialHandlers,
|
dnsMuxHandlers: tt.initialHandlers,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
service: &mockService{},
|
service: &mockService{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform the update
|
// Perform the update
|
||||||
server.updateMux(tt.updates)
|
server.updateMux(tt.updates)
|
||||||
|
|
||||||
// Verify the results
|
// Verify the results
|
||||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||||
"Number of handlers after update doesn't match expected")
|
"Number of handlers after update doesn't match expected")
|
||||||
|
|
||||||
// Check each expected handler
|
// Check each expected handler
|
||||||
for id, expectedDomain := range tt.expectedHandlers {
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
var found *handlerWrapper
|
||||||
assert.True(t, exists, "Expected handler %s not found", id)
|
for i := range server.dnsMuxHandlers {
|
||||||
if exists {
|
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||||
assert.Equal(t, expectedDomain, handler.domain,
|
found = &server.dnsMuxHandlers[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||||
|
if found != nil {
|
||||||
|
assert.Equal(t, expectedDomain, found.domain,
|
||||||
"Domain mismatch for handler %s", id)
|
"Domain mismatch for handler %s", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify no unexpected handlers exist
|
// Verify no unexpected handlers exist
|
||||||
for HandlerID := range server.dnsMuxMap {
|
for _, entry := range server.dnsMuxHandlers {
|
||||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the handlerChain state and order
|
// Verify the handlerChain state and order
|
||||||
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
// Verify handler exists in mux
|
// Verify handler exists in mux
|
||||||
foundInMux := false
|
foundInMux := false
|
||||||
for _, muxEntry := range server.dnsMuxMap {
|
for _, muxEntry := range server.dnsMuxHandlers {
|
||||||
if chainEntry.Handler == muxEntry.handler &&
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
chainEntry.Priority == muxEntry.priority &&
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||||
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, foundInMux,
|
assert.True(t, foundInMux,
|
||||||
"Handler in chain not found in dnsMuxMap")
|
"Handler in chain not found in dnsMuxHandlers")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||||
|
// for the given fqdn pattern at the given priority.
|
||||||
|
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||||
|
for _, h := range s.handlerChain.handlers {
|
||||||
|
if h.OrigPattern == pattern && h.Priority == priority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||||
|
// tracks each (handler, domain) registration independently when one handler
|
||||||
|
// serves multiple zones. Every custom zone is served by the same handler
|
||||||
|
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||||
|
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||||
|
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||||
|
// zones onto one entry, leaving removed zones in the chain to answer
|
||||||
|
// authoritatively with no records.
|
||||||
|
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||||
|
// One handler serves every custom zone, mirroring s.localResolver.
|
||||||
|
shared := &mockHandler{Id: "local-resolver"}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two custom zones under the same handler. The surviving zone is registered
|
||||||
|
// last, mirroring the management emission order.
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||||
|
"userzone.test should be registered after the first update")
|
||||||
|
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||||
|
"peerzone.test should be registered after the first update")
|
||||||
|
|
||||||
|
// Remove one zone, keep the other.
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||||
|
"peerzone.test should remain after removing userzone.test")
|
||||||
|
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||||
|
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||||
|
// does not tear down the shared local resolver during reconfiguration. The
|
||||||
|
// resolver is a process-lifetime singleton reused across config updates;
|
||||||
|
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||||
|
// resolution) and clears its records. updateMux must deregister its chain
|
||||||
|
// entries without stopping it. Records surviving a teardown update is the
|
||||||
|
// observable proxy: Stop() would have cleared them.
|
||||||
|
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||||
|
resolver := local.NewResolver()
|
||||||
|
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||||
|
Name: "peer.netbird.cloud.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "10.0.0.1",
|
||||||
|
}))
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
localResolver: resolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Remove the zone. The resolver must survive so its records and lookup
|
||||||
|
// context stay intact for the next registration.
|
||||||
|
server.updateMux(nil)
|
||||||
|
|
||||||
|
var response *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
response = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||||
|
|
||||||
|
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||||
|
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||||
|
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||||
|
}
|
||||||
|
|
||||||
func TestExtraDomains(t *testing.T) {
|
func TestExtraDomains(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
|||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
groups := []*nbdns.NameServerGroup{
|
groups := []*nbdns.NameServerGroup{
|
||||||
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||||
// without spinning up real handlers.
|
// without spinning up real handlers.
|
||||||
type healthStubHandler struct {
|
type healthStubHandler struct {
|
||||||
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||||
activeRoutes: func() route.HAMap { return fx.active },
|
activeRoutes: func() route.HAMap { return fx.active },
|
||||||
warningDelayBase: defaultWarningDelayBase,
|
warningDelayBase: defaultWarningDelayBase,
|
||||||
}
|
}
|
||||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
fx.server.mux.Lock()
|
fx.server.mux.Lock()
|
||||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||||
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return nil },
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
warningDelayBase: 50 * time.Millisecond,
|
warningDelayBase: 50 * time.Millisecond,
|
||||||
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
|||||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
}}
|
}}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
service: NewServiceViaMemory(wgIface),
|
service: NewServiceViaMemory(wgIface),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
extraDomains: map[domain.Domain]int{},
|
extraDomains: map[domain.Domain]int{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
statusRecorder: peer.NewRecorder("mgm"),
|
statusRecorder: peer.NewRecorder("mgm"),
|
||||||
selectedRoutes: func() route.HAMap { return nil },
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||||
}
|
}
|
||||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||||
// comes up and a query succeeds before the grace window elapses. No
|
// comes up and a query succeeds before the grace window elapses. No
|
||||||
// warning should ever have fired, and no recovery either.
|
// warning should ever have fired, and no recovery either.
|
||||||
|
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
set bool
|
||||||
|
val string
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||||
|
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||||
|
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||||
|
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||||
|
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||||
|
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Setenv(envWarningDelay, tc.val)
|
||||||
|
if !tc.set {
|
||||||
|
os.Unsetenv(envWarningDelay)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||||
fx := newProjTestFixture(t)
|
fx := newProjTestFixture(t)
|
||||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||||
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
|||||||
server := &DefaultServer{
|
server := &DefaultServer{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
warningDelayBase: time.Hour,
|
warningDelayBase: time.Hour,
|
||||||
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
|||||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
|
|||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -443,29 +443,32 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
|||||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||||
|
u.markUpstreamOk(upstream)
|
||||||
|
|
||||||
proto := ""
|
proto := ""
|
||||||
if upstreamProto != nil {
|
if upstreamProto != nil {
|
||||||
proto = upstreamProto.protocol
|
proto = upstreamProto.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||||
|
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||||
|
// refused zones, transient recursion errors), not reachability
|
||||||
|
// problems: fail over for a better answer but keep the upstream healthy.
|
||||||
if code, ok := nonRetryableEDE(rm); ok {
|
if code, ok := nonRetryableEDE(rm); ok {
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
resutil.StripOPT(rm)
|
||||||
}
|
}
|
||||||
u.markUpstreamOk(upstream)
|
|
||||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||||
}
|
}
|
||||||
reason := dns.RcodeToString[rm.Rcode]
|
reason := dns.RcodeToString[rm.Rcode]
|
||||||
u.markUpstreamFail(upstream, reason)
|
|
||||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
resutil.StripOPT(rm)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.markUpstreamOk(upstream)
|
|
||||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,22 +523,6 @@ func upstreamUDPSize() uint16 {
|
|||||||
return dns.MinMsgSize
|
return dns.MinMsgSize
|
||||||
}
|
}
|
||||||
|
|
||||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
|
||||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
|
||||||
func stripOPT(rm *dns.Msg) {
|
|
||||||
if len(rm.Extra) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out := rm.Extra[:0]
|
|
||||||
for _, rr := range rm.Extra {
|
|
||||||
if _, ok := rr.(*dns.OPT); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, rr)
|
|
||||||
}
|
|
||||||
rm.Extra = out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||||
|
|||||||
@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
|||||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||||
|
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||||
|
// those are per-question outcomes from a reachable server and must not mark
|
||||||
|
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||||
|
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||||
|
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||||
|
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||||
|
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
respA mockUpstreamResponse
|
||||||
|
respB mockUpstreamResponse
|
||||||
|
wantHealthy bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "both SERVFAIL are reachable",
|
||||||
|
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
wantHealthy: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both REFUSED are reachable",
|
||||||
|
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
wantHealthy: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout marks unhealthy",
|
||||||
|
respA: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
respB: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
wantHealthy: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
a.String(): tc.respA,
|
||||||
|
b.String(): tc.respB,
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{a, b})
|
||||||
|
|
||||||
|
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||||
|
|
||||||
|
health := resolver.UpstreamHealth()
|
||||||
|
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||||
|
if tc.wantHealthy {
|
||||||
|
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||||
|
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||||
|
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||||
|
} else {
|
||||||
|
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||||
|
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFormatFailures(t *testing.T) {
|
func TestFormatFailures(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -913,19 +985,6 @@ func TestEDEName(t *testing.T) {
|
|||||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStripOPT(t *testing.T) {
|
|
||||||
rm := &dns.Msg{
|
|
||||||
Extra: []dns.RR{
|
|
||||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
|
||||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
stripOPT(rm)
|
|
||||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
|
||||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
|
||||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
|||||||
@@ -26,6 +26,15 @@ import (
|
|||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
const upstreamTimeout = 15 * time.Second
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||||
|
// client can see the reason without inspecting this peer's logs. They live in
|
||||||
|
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||||
|
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||||
|
const (
|
||||||
|
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||||
|
edeNetbirdUpstreamFailure uint16 = 49153
|
||||||
|
)
|
||||||
|
|
||||||
type resolver interface {
|
type resolver interface {
|
||||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
}
|
}
|
||||||
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
|||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
reqHasEdns bool,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reqHasEdns {
|
||||||
|
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||||
|
}
|
||||||
|
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
|||||||
|
|
||||||
return selectedResId, matches
|
return selectedResId, matches
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||||
|
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||||
|
if dnsErr != nil && dnsErr.IsTimeout {
|
||||||
|
return edeNetbirdUpstreamTimeout
|
||||||
|
}
|
||||||
|
return edeNetbirdUpstreamFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||||
|
// It deliberately omits the upstream server address, which may be an internal
|
||||||
|
// resolver and is exposed to any client permitted to use the route; the full
|
||||||
|
// detail stays in the forwarder's local log.
|
||||||
|
func edeText(dnsErr *net.DNSError) string {
|
||||||
|
if dnsErr != nil && dnsErr.IsTimeout {
|
||||||
|
return "netbird forwarder: upstream timeout"
|
||||||
|
}
|
||||||
|
return "netbird forwarder: upstream failure"
|
||||||
|
}
|
||||||
|
|
||||||
|
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||||
|
// creating the OPT pseudo-record if the response does not already carry one.
|
||||||
|
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||||
|
opt := resp.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
opt = resp.IsEdns0()
|
||||||
|
}
|
||||||
|
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
lookupErr error
|
||||||
|
reqEdns bool
|
||||||
|
wantEDE bool
|
||||||
|
wantCode uint16
|
||||||
|
wantTextHas string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "timeout with edns0",
|
||||||
|
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||||
|
reqEdns: true,
|
||||||
|
wantEDE: true,
|
||||||
|
wantCode: edeNetbirdUpstreamTimeout,
|
||||||
|
wantTextHas: "netbird forwarder: upstream timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server failure with edns0",
|
||||||
|
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||||
|
reqEdns: true,
|
||||||
|
wantEDE: true,
|
||||||
|
wantCode: edeNetbirdUpstreamFailure,
|
||||||
|
wantTextHas: "netbird forwarder: upstream failure",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no edns0 in request omits ede",
|
||||||
|
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||||
|
reqEdns: false,
|
||||||
|
wantEDE: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||||
|
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
if tt.reqEdns {
|
||||||
|
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "expected a response")
|
||||||
|
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||||
|
|
||||||
|
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||||
|
if !tt.wantEDE {
|
||||||
|
assert.False(t, ok, "response must not carry EDE")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.True(t, ok, "response must carry EDE")
|
||||||
|
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||||
|
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||||
|
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
// Test that large UDP responses are truncated with TC bit set
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ const (
|
|||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
|
|
||||||
|
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||||
|
|
||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
@@ -199,6 +201,8 @@ type Engine struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
started bool
|
||||||
|
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
|
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
@@ -279,9 +283,15 @@ func NewEngine(
|
|||||||
services EngineServices,
|
services EngineServices,
|
||||||
mobileDep MobileDependency,
|
mobileDep MobileDependency,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
// The engine is single-use: a fresh instance is built per connection
|
||||||
|
// cycle (see Client.run), so the run context is created once here rather
|
||||||
|
// than in Start.
|
||||||
|
ctx, cancel := context.WithCancel(clientCtx)
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
signal: services.SignalClient,
|
signal: services.SignalClient,
|
||||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||||
mgmClient: services.MgmClient,
|
mgmClient: services.MgmClient,
|
||||||
@@ -314,8 +324,34 @@ func (e *Engine) Stop() error {
|
|||||||
log.Debugf("tried stopping engine that is nil")
|
log.Debugf("tried stopping engine that is nil")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
e.cancel()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
|
|
||||||
|
e.stopLocked()
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
timeout := e.calculateShutdownTimeout()
|
||||||
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||||
|
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopLocked tears down everything Start may have brought up, in the order
|
||||||
|
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||||
|
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||||
|
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||||
|
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||||
|
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||||
|
func (e *Engine) stopLocked() {
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
}
|
}
|
||||||
@@ -366,10 +402,6 @@ func (e *Engine) Stop() error {
|
|||||||
// so dbus and friends don't complain because of a missing interface
|
// so dbus and friends don't complain because of a missing interface
|
||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
|
||||||
e.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
@@ -388,21 +420,6 @@ func (e *Engine) Stop() error {
|
|||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
|
||||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||||
@@ -440,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
|||||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
// The engine is single-use. Reject a duplicate start and a start on an
|
||||||
|
// already-stopped engine (run context cancelled).
|
||||||
|
if e.started {
|
||||||
|
return ErrEngineAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||||
|
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.started = true
|
||||||
|
|
||||||
|
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||||
|
// run context first so goroutines started before the failure (connMgr,
|
||||||
|
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||||
|
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||||
|
// not just what close() covers.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
e.cancel()
|
||||||
|
e.stopLocked()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.cancel != nil {
|
|
||||||
e.cancel()
|
|
||||||
}
|
|
||||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
|
||||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
@@ -485,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("read initial settings: %w", err)
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
@@ -526,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
if err = e.wgInterfaceCreate(); err != nil {
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
e.close()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -547,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.udpMux, err = e.wgInterface.Up()
|
e.udpMux, err = e.wgInterface.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -572,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.acl = acl.NewDefaultManager(e.firewall)
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.Initialize()
|
if err := e.dnsServer.Initialize(); err != nil {
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start(peer.IsForceRelayed())
|
e.srWatcher.Start(peer.IsForceRelayed())
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
if err = e.receiveSignalEvents(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveJobEvents()
|
e.receiveJobEvents()
|
||||||
|
|
||||||
@@ -638,7 +670,6 @@ func (e *Engine) createFirewall() error {
|
|||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("set firewall: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1698,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
}
|
}
|
||||||
|
|
||||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() error {
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
@@ -1769,7 +1800,12 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
e.signal.WaitStreamConnected()
|
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||||
|
e.signal.WaitStreamConnected(e.ctx)
|
||||||
|
if err := e.ctx.Err(); err != nil {
|
||||||
|
return fmt.Errorf("wait for signal stream: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||||
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||||
|
|||||||
@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||||
|
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||||
|
// the original client did not request EDNS0.
|
||||||
|
hadEdns := r.IsEdns0() != nil
|
||||||
|
if !hadEdns {
|
||||||
|
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
}
|
||||||
|
|
||||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||||
|
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||||
|
}
|
||||||
|
if !hadEdns {
|
||||||
|
resutil.StripOPT(reply)
|
||||||
|
}
|
||||||
|
|
||||||
resutil.SetMeta(w, "peer", peerKey)
|
resutil.SetMeta(w, "peer", peerKey)
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ type URLOpener interface {
|
|||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
cfgPath string
|
cfgPath string
|
||||||
}
|
}
|
||||||
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||||
|
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||||
|
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||||
|
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||||
|
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||||
|
// process (decoupled from the network extension), so without this the server
|
||||||
|
// lingers after the user dismisses the browser and the next connect stalls
|
||||||
|
// trying to bind the same port.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
return &Auth{
|
return &Auth{
|
||||||
ctx: context.Background(),
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
cfgPath: cfgPath,
|
cfgPath: cfgPath,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
|||||||
|
|
||||||
// NewAuthWithConfig instantiate Auth based on existing config
|
// NewAuthWithConfig instantiate Auth based on existing config
|
||||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
return &Auth{
|
return &Auth{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||||
|
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||||
|
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||||
|
// safe to call when no login is running.
|
||||||
|
func (a *Auth) Stop() {
|
||||||
|
if a.cancel != nil {
|
||||||
|
a.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
|
|||||||
@@ -993,6 +993,10 @@ func (s *Server) cleanupConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||||
|
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||||
|
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||||
|
// making the run loop the sole owner of engine shutdown.
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
91
combined/cmd/admin.go
Normal file
91
combined/cmd/admin.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newAdminCommands creates the admin command tree with combined-specific resource openers.
|
||||||
|
func newAdminCommands() *cobra.Command {
|
||||||
|
cmd := admincmd.NewCommands(withAdminResources)
|
||||||
|
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
|
||||||
|
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||||
|
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
|
||||||
|
mgmtConfig, err := cfg.ToManagementConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create management config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := idpStorage.Close(); err != nil {
|
||||||
|
log.Debugf("close embedded IdP storage: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// withAdminTokenStore opens only the management store for admin token commands.
|
||||||
|
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
|
||||||
|
return fn(ctx, managementStore)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
cfg, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||||
|
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||||
|
case "postgres":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||||
|
case "mysql":
|
||||||
|
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if file := cfg.Server.Store.File; file != "" {
|
||||||
|
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||||
|
}
|
||||||
|
|
||||||
|
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := managementStore.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, managementStore, cfg)
|
||||||
|
}
|
||||||
@@ -64,7 +64,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||||
|
|
||||||
rootCmd.AddCommand(newTokenCommands())
|
rootCmd.AddCommand(newAdminCommands())
|
||||||
}
|
}
|
||||||
|
|
||||||
func RootCmd() *cobra.Command {
|
func RootCmd() *cobra.Command {
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
|
||||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
|
||||||
func newTokenCommands() *cobra.Command {
|
|
||||||
return tokencmd.NewCommands(withTokenStore)
|
|
||||||
}
|
|
||||||
|
|
||||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
|
||||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
|
||||||
if err := util.InitLog("error", "console"); err != nil {
|
|
||||||
return fmt.Errorf("init log: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
|
||||||
|
|
||||||
cfg, err := LoadConfig(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
|
||||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
|
||||||
case "postgres":
|
|
||||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
|
||||||
case "mysql":
|
|
||||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if file := cfg.Server.Store.File; file != "" {
|
|
||||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
|
||||||
}
|
|
||||||
|
|
||||||
datadir := cfg.Management.DataDir
|
|
||||||
engine := types.Engine(cfg.Management.Store.Engine)
|
|
||||||
|
|
||||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create store: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := s.Close(ctx); err != nil {
|
|
||||||
log.Debugf("close store: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return fn(ctx, s)
|
|
||||||
}
|
|
||||||
616
infrastructure_files/getting-started-enterprise.sh
Executable file
616
infrastructure_files/getting-started-enterprise.sh
Executable file
@@ -0,0 +1,616 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -o pipefail
|
||||||
|
|
||||||
|
# NetBird Enterprise — Getting Started
|
||||||
|
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||||
|
# embedded identity provider. Owner is created via first-login flow.
|
||||||
|
|
||||||
|
SED_STRIP_PADDING='s/=//g'
|
||||||
|
|
||||||
|
check_docker_compose() {
|
||||||
|
if command -v docker-compose &> /dev/null; then
|
||||||
|
echo "docker-compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if docker compose --help &> /dev/null; then
|
||||||
|
echo "docker compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_openssl() {
|
||||||
|
if ! command -v openssl &> /dev/null; then
|
||||||
|
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_secret() {
|
||||||
|
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_b64_key() {
|
||||||
|
openssl rand -base64 32
|
||||||
|
}
|
||||||
|
|
||||||
|
check_nb_domain() {
|
||||||
|
local domain="$1"
|
||||||
|
if [[ -z "$domain" ]]; then
|
||||||
|
echo "The domain cannot be empty." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ "$domain" == "netbird.example.com" ]]; then
|
||||||
|
echo "The domain cannot be netbird.example.com" > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
|
||||||
|
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
|
||||||
|
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
check_domain_resolves() {
|
||||||
|
local domain="$1"
|
||||||
|
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
|
||||||
|
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
|
||||||
|
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
|
||||||
|
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
read_nb_domain() {
|
||||||
|
local value=""
|
||||||
|
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if ! check_nb_domain "$value"; then
|
||||||
|
read_nb_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if ! check_domain_resolves "$value"; then
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
|
||||||
|
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
|
||||||
|
local confirm=""
|
||||||
|
echo -n "Continue anyway? [y/N]: " > /dev/stderr
|
||||||
|
read -r confirm < /dev/tty
|
||||||
|
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
|
||||||
|
read_nb_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_required() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_secret() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -rs value < /dev/tty
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
# read_yes_no "<prompt>" [<default y|n>]
|
||||||
|
read_yes_no() {
|
||||||
|
local prompt="$1"
|
||||||
|
local default="${2:-n}"
|
||||||
|
local hint
|
||||||
|
if [[ "$default" == "y" ]]; then
|
||||||
|
hint="[Y/n]"
|
||||||
|
else
|
||||||
|
hint="[y/N]"
|
||||||
|
fi
|
||||||
|
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||||
|
local ans=""
|
||||||
|
read -r ans < /dev/tty
|
||||||
|
if [[ -z "$ans" ]]; then
|
||||||
|
ans="$default"
|
||||||
|
fi
|
||||||
|
case "$ans" in
|
||||||
|
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||||
|
*) echo "no" ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_postgres() {
|
||||||
|
set +e
|
||||||
|
echo -n "Waiting for postgres to become ready"
|
||||||
|
local counter=1
|
||||||
|
while true; do
|
||||||
|
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $counter -eq 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Postgres is taking too long. Recent logs:"
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
set -e
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment() {
|
||||||
|
check_openssl
|
||||||
|
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||||
|
|
||||||
|
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
|
||||||
|
echo "Generated files already exist in $(pwd)."
|
||||||
|
echo "If you want to reinitialize the environment, please remove them first:"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
|
||||||
|
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
|
||||||
|
echo "Be aware this will remove all data from the database."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "NetBird Enterprise bootstrap"
|
||||||
|
echo ""
|
||||||
|
echo "Traffic flow:"
|
||||||
|
echo " Enables traffic events logging on the management server."
|
||||||
|
echo " When enabled, the NetBird stack also runs NATS along with two"
|
||||||
|
echo " additional containers: netbird-receiver (the traffic log receiver"
|
||||||
|
echo " service) and netbird-enricher (the traffic log enricher service)."
|
||||||
|
echo " It still has to be turned on from the dashboard settings afterwards."
|
||||||
|
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
|
||||||
|
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
|
||||||
|
|
||||||
|
GHCR_USERNAME="netbirdExtAccess1"
|
||||||
|
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
|
||||||
|
|
||||||
|
POSTGRES_USER="netbird"
|
||||||
|
POSTGRES_DB="netbird"
|
||||||
|
POSTGRES_PASSWORD=$(rand_secret)
|
||||||
|
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
|
||||||
|
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
|
||||||
|
|
||||||
|
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
|
||||||
|
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Selected:"
|
||||||
|
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
|
||||||
|
echo " Domain: ${NETBIRD_DOMAIN}"
|
||||||
|
echo ""
|
||||||
|
echo "Rendering files into $(pwd) ..."
|
||||||
|
install -m 600 /dev/null .env
|
||||||
|
render_env >> .env
|
||||||
|
render_docker_compose > docker-compose.yml
|
||||||
|
|
||||||
|
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
|
||||||
|
fi
|
||||||
|
render_caddyfile > Caddyfile
|
||||||
|
install -m 600 /dev/null config.yaml
|
||||||
|
render_config_yaml >> config.yaml
|
||||||
|
|
||||||
|
echo "Logging in to ghcr.io ..."
|
||||||
|
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||||
|
unset GHCR_TOKEN
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Pulling images ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND pull
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting postgres ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||||
|
sleep 2
|
||||||
|
wait_postgres
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting remaining services ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done."
|
||||||
|
echo ""
|
||||||
|
echo "Dashboard: https://${NETBIRD_DOMAIN}"
|
||||||
|
echo ""
|
||||||
|
echo "Open the dashboard in a browser to complete the first-login owner setup."
|
||||||
|
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
|
||||||
|
echo ""
|
||||||
|
echo "Tail logs:"
|
||||||
|
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Renderers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
render_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# Generated by getting-started-enterprise.sh
|
||||||
|
# Holds all configuration and secrets for the stack. Mode 600.
|
||||||
|
|
||||||
|
# Features (set by the script; don't edit without re-running)
|
||||||
|
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||||
|
|
||||||
|
# Domain
|
||||||
|
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
|
||||||
|
|
||||||
|
# Image tags. Default to "latest"
|
||||||
|
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
|
||||||
|
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
|
||||||
|
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
# License keys
|
||||||
|
EOF
|
||||||
|
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
# Postgres
|
||||||
|
POSTGRES_USER=${POSTGRES_USER}
|
||||||
|
POSTGRES_DB=${POSTGRES_DB}
|
||||||
|
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||||
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
|
||||||
|
|
||||||
|
# Relay
|
||||||
|
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
|
||||||
|
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||||
|
|
||||||
|
# Datastore encryption
|
||||||
|
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
|
||||||
|
# Dashboard OIDC scopes
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_docker_compose() {
|
||||||
|
render_compose_header
|
||||||
|
render_compose_common
|
||||||
|
render_compose_server
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
render_compose_flow
|
||||||
|
fi
|
||||||
|
render_compose_postgres
|
||||||
|
render_compose_footer
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_header() {
|
||||||
|
cat <<'EOF'
|
||||||
|
x-default: &default
|
||||||
|
restart: unless-stopped
|
||||||
|
logging:
|
||||||
|
driver: json-file
|
||||||
|
options:
|
||||||
|
max-size: '500m'
|
||||||
|
max-file: '2'
|
||||||
|
|
||||||
|
services:
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_common() {
|
||||||
|
cat <<'EOF'
|
||||||
|
caddy:
|
||||||
|
<<: *default
|
||||||
|
image: caddy:2
|
||||||
|
container_name: netbird-caddy
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
- '443:443/udp'
|
||||||
|
- '80:80'
|
||||||
|
volumes:
|
||||||
|
- netbird_caddy_data:/data
|
||||||
|
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||||
|
|
||||||
|
dashboard:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||||
|
container_name: netbird-dashboard
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||||
|
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||||
|
- AUTH_AUDIENCE=netbird-dashboard
|
||||||
|
- AUTH_CLIENT_ID=netbird-dashboard
|
||||||
|
- AUTH_CLIENT_SECRET=
|
||||||
|
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||||
|
- USE_AUTH0=false
|
||||||
|
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||||
|
- AUTH_REDIRECT_URI=/nb-auth
|
||||||
|
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||||
|
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||||
|
- NGINX_SSL_PORT=443
|
||||||
|
- LETSENCRYPT_DOMAIN=
|
||||||
|
- LETSENCRYPT_EMAIL=
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_server() {
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird-server:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
|
||||||
|
container_name: netbird-server
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
dashboard:
|
||||||
|
condition: service_started
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
ports:
|
||||||
|
- '3478:3478/udp'
|
||||||
|
volumes:
|
||||||
|
- netbird_data:/var/lib/netbird
|
||||||
|
- ./config.yaml:/etc/netbird/config.yaml
|
||||||
|
command: ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_flow() {
|
||||||
|
cat <<'EOF'
|
||||||
|
nats:
|
||||||
|
<<: *default
|
||||||
|
image: nats:2
|
||||||
|
container_name: netbird-nats
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- netbird_nats_data:/data
|
||||||
|
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||||
|
|
||||||
|
enricher:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
|
||||||
|
container_name: netbird-enricher
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
volumes:
|
||||||
|
- netbird_enricher:/var/lib/netbird
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
- NB_DATADIR=/var/lib/netbird
|
||||||
|
- NB_MANAGEMENT_STORE_ENGINE=postgres
|
||||||
|
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
|
||||||
|
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
- NB_FLOW_ADAPTER_TYPE=nats
|
||||||
|
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||||
|
- NB_FLOW_NATS_STREAM=traffic-events
|
||||||
|
- NB_METRICS_PORT=9091
|
||||||
|
- NB_PERSISTENCE_RETENTION_PERIOD=168h
|
||||||
|
|
||||||
|
receiver:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
|
||||||
|
container_name: netbird-receiver
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
- NB_FLOW_LISTEN_PORT=80
|
||||||
|
- NB_FLOW_ADAPTER_TYPE=nats
|
||||||
|
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||||
|
- NB_FLOW_NATS_STREAM=traffic-events
|
||||||
|
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_postgres() {
|
||||||
|
cat <<'EOF'
|
||||||
|
postgres:
|
||||||
|
<<: *default
|
||||||
|
image: postgres:17
|
||||||
|
container_name: netbird-postgres
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=${POSTGRES_USER}
|
||||||
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||||
|
- POSTGRES_DB=${POSTGRES_DB}
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 10
|
||||||
|
volumes:
|
||||||
|
- netbird_postgres:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_footer() {
|
||||||
|
cat <<'EOF'
|
||||||
|
volumes:
|
||||||
|
netbird_data:
|
||||||
|
EOF
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird_nats_data:
|
||||||
|
netbird_enricher:
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird_postgres:
|
||||||
|
netbird_caddy_data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_caddyfile() {
|
||||||
|
cat <<'EOF'
|
||||||
|
{
|
||||||
|
servers :80,:443 {
|
||||||
|
protocols h1 h2c h2 h3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(security_headers) {
|
||||||
|
header * {
|
||||||
|
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||||
|
X-Content-Type-Options "nosniff"
|
||||||
|
X-Frame-Options "SAMEORIGIN"
|
||||||
|
X-XSS-Protection "1; mode=block"
|
||||||
|
-Server
|
||||||
|
Referrer-Policy strict-origin-when-cross-origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:80 {
|
||||||
|
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
|
||||||
|
}
|
||||||
|
|
||||||
|
{$CADDY_SECURE_DOMAIN}:443 {
|
||||||
|
import security_headers
|
||||||
|
# Signal (gRPC over h2c)
|
||||||
|
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
|
||||||
|
# Management (gRPC over h2c + HTTP)
|
||||||
|
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
|
||||||
|
reverse_proxy /api/* netbird-server:80
|
||||||
|
reverse_proxy /ws-proxy/* netbird-server:80
|
||||||
|
# Embedded IdP (OAuth2 endpoints served by netbird server)
|
||||||
|
reverse_proxy /oauth2/* netbird-server:80
|
||||||
|
# Relay (WebSocket multiplexed on the same port)
|
||||||
|
reverse_proxy /relay* netbird-server:80
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<'EOF'
|
||||||
|
# Flow receiver (gRPC over h2c)
|
||||||
|
reverse_proxy /flow.FlowService/* h2c://receiver:80
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<'EOF'
|
||||||
|
# Dashboard
|
||||||
|
reverse_proxy /* dashboard:80
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_config_yaml() {
|
||||||
|
cat <<EOF
|
||||||
|
# NetBird Enterprise server configuration.
|
||||||
|
# Generated by getting-started-enterprise.sh. Mode 600.
|
||||||
|
|
||||||
|
server:
|
||||||
|
listenAddress: ":80"
|
||||||
|
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
|
||||||
|
|
||||||
|
metricsPort: 9090
|
||||||
|
healthcheckAddress: ":9000"
|
||||||
|
|
||||||
|
logLevel: "info"
|
||||||
|
logFile: "console"
|
||||||
|
|
||||||
|
# TLS is terminated by Caddy in front; leave this block empty.
|
||||||
|
tls:
|
||||||
|
certFile: ""
|
||||||
|
keyFile: ""
|
||||||
|
letsencrypt:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
|
||||||
|
dataDir: "/var/lib/netbird/"
|
||||||
|
|
||||||
|
disableAnonymousMetrics: false
|
||||||
|
disableGeoliteUpdate: false
|
||||||
|
|
||||||
|
auth:
|
||||||
|
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
|
||||||
|
localAuthDisabled: false
|
||||||
|
signKeyRefreshEnabled: false
|
||||||
|
dashboardRedirectURIs:
|
||||||
|
- "https://${NETBIRD_DOMAIN}/nb-auth"
|
||||||
|
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
|
||||||
|
cliRedirectURIs:
|
||||||
|
- "http://localhost:53000/"
|
||||||
|
|
||||||
|
store:
|
||||||
|
engine: "postgres"
|
||||||
|
dsn: "${POSTGRES_DSN}"
|
||||||
|
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
|
||||||
|
|
||||||
|
activityStore:
|
||||||
|
engine: "postgres"
|
||||||
|
dsn: "${POSTGRES_DSN}"
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
trafficFlow:
|
||||||
|
enabled: true
|
||||||
|
address: "https://${NETBIRD_DOMAIN}:443"
|
||||||
|
interval: "60s"
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment
|
||||||
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
@@ -0,0 +1,638 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -o pipefail
|
||||||
|
|
||||||
|
# NetBird — community combined → Enterprise combined migration
|
||||||
|
#
|
||||||
|
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||||
|
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||||
|
# existing files. Original docker-compose.yml and config.yaml are never
|
||||||
|
# modified.
|
||||||
|
#
|
||||||
|
# Steps (all optional, asked interactively):
|
||||||
|
# 1. Image swap — replace community images with enterprise cloud images.
|
||||||
|
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||||
|
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||||
|
#
|
||||||
|
# To revert:
|
||||||
|
# docker compose down
|
||||||
|
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||||
|
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||||
|
# # at the end of this script's run.
|
||||||
|
# docker compose up -d
|
||||||
|
|
||||||
|
OVERRIDE_FILE="docker-compose.override.yml"
|
||||||
|
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||||
|
|
||||||
|
check_docker_compose() {
|
||||||
|
if command -v docker-compose &> /dev/null; then
|
||||||
|
echo "docker-compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if docker compose --help &> /dev/null; then
|
||||||
|
echo "docker compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "docker-compose is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_yq() {
|
||||||
|
if ! command -v yq &> /dev/null; then
|
||||||
|
cat > /dev/stderr <<'EOF'
|
||||||
|
yq is required to parse and update YAML safely.
|
||||||
|
|
||||||
|
macOS: brew install yq
|
||||||
|
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||||
|
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||||
|
|
||||||
|
EOF
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||||
|
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
check_openssl() {
|
||||||
|
if ! command -v openssl &> /dev/null; then
|
||||||
|
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_password() {
|
||||||
|
openssl rand -hex 32
|
||||||
|
}
|
||||||
|
|
||||||
|
read_required() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_secret() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -rs value < /dev/tty
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_yes_no() {
|
||||||
|
local prompt="$1"
|
||||||
|
local default="${2:-n}"
|
||||||
|
local hint
|
||||||
|
if [[ "$default" == "y" ]]; then
|
||||||
|
hint="[Y/n]"
|
||||||
|
else
|
||||||
|
hint="[y/N]"
|
||||||
|
fi
|
||||||
|
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||||
|
local ans=""
|
||||||
|
read -r ans < /dev/tty
|
||||||
|
if [[ -z "$ans" ]]; then
|
||||||
|
ans="$default"
|
||||||
|
fi
|
||||||
|
case "$ans" in
|
||||||
|
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||||
|
*) echo "no" ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Detection — read the operator's existing compose to find service names and
|
||||||
|
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
detect_combined_service() {
|
||||||
|
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_dashboard_service() {
|
||||||
|
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_config_yaml_host_path() {
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_data_volume() {
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_exposed_address() {
|
||||||
|
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_compose_network() {
|
||||||
|
local tag
|
||||||
|
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
|
||||||
|
case "$tag" in
|
||||||
|
"!!seq")
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
|
||||||
|
;;
|
||||||
|
"!!map")
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "default"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Renderers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Build docker-compose.override.yml from the steps the operator selected.
|
||||||
|
# Service names match what we detected on the operator's side.
|
||||||
|
render_override() {
|
||||||
|
cat <<EOF
|
||||||
|
# Generated by migrate-to-enterprise.sh. Mode 644.
|
||||||
|
# Merged with docker-compose.yml automatically by Docker Compose.
|
||||||
|
# Remove this file (and config.yaml.enterprise if present) to revert.
|
||||||
|
|
||||||
|
services:
|
||||||
|
${DASHBOARD_SERVICE}:
|
||||||
|
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
|
||||||
|
|
||||||
|
${COMBINED_SERVICE}:
|
||||||
|
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
volumes:
|
||||||
|
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
|
||||||
|
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
|
||||||
|
|
||||||
|
postgres:
|
||||||
|
image: postgres:17
|
||||||
|
container_name: netbird-postgres
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: netbird
|
||||||
|
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
|
||||||
|
POSTGRES_DB: netbird
|
||||||
|
volumes:
|
||||||
|
- netbird_postgres:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 20
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
nats:
|
||||||
|
image: nats:2
|
||||||
|
container_name: netbird-nats
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||||
|
volumes:
|
||||||
|
- netbird_nats_data:/data
|
||||||
|
|
||||||
|
flow-enricher:
|
||||||
|
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
|
||||||
|
container_name: netbird-flow-enricher
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
NB_DATADIR: /var/lib/netbird
|
||||||
|
NB_MANAGEMENT_STORE_ENGINE: postgres
|
||||||
|
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
|
||||||
|
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
NB_FLOW_ADAPTER_TYPE: nats
|
||||||
|
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||||
|
NB_FLOW_NATS_STREAM: traffic-events
|
||||||
|
NB_METRICS_PORT: 9091
|
||||||
|
NB_PERSISTENCE_RETENTION_PERIOD: 168h
|
||||||
|
|
||||||
|
flow-receiver:
|
||||||
|
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
|
||||||
|
container_name: netbird-flow-receiver
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
depends_on:
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
NB_FLOW_LISTEN_PORT: 80
|
||||||
|
NB_FLOW_ADAPTER_TYPE: nats
|
||||||
|
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||||
|
NB_FLOW_NATS_STREAM: traffic-events
|
||||||
|
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
|
||||||
|
- traefik.http.routers.netbird-flow.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-flow.tls=true
|
||||||
|
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
|
||||||
|
- traefik.http.routers.netbird-flow.priority=100
|
||||||
|
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Volume declarations for anything new the override introduced
|
||||||
|
local has_volumes="no"
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
has_volumes="yes"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$has_volumes" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
EOF
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo " netbird_postgres:"
|
||||||
|
fi
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
echo " netbird_nats_data:"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build config.yaml.enterprise by yq-editing the operator's existing
|
||||||
|
# config.yaml. We don't touch the original file.
|
||||||
|
render_enterprise_config() {
|
||||||
|
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
|
||||||
|
yq eval "
|
||||||
|
.server.store.engine = \"postgres\" |
|
||||||
|
.server.store.dsn = \"$pg_dsn\" |
|
||||||
|
.server.activityStore.engine = \"postgres\" |
|
||||||
|
.server.activityStore.dsn = \"$pg_dsn\" |
|
||||||
|
.server.authStore.engine = \"postgres\" |
|
||||||
|
.server.authStore.dsn = \"$pg_dsn\"
|
||||||
|
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
local flow_addr="${NETBIRD_DOMAIN}"
|
||||||
|
yq eval -i "
|
||||||
|
.server.trafficFlow.enabled = true |
|
||||||
|
.server.trafficFlow.address = \"$flow_addr\" |
|
||||||
|
.server.trafficFlow.interval = \"60s\"
|
||||||
|
" "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Execution steps
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
resolve_data_volume() {
|
||||||
|
local short="$1"
|
||||||
|
local actual
|
||||||
|
# Resolve project-prefixed volume name from Docker Compose config first.
|
||||||
|
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
|
||||||
|
if [[ -n "$actual" && "$actual" != "null" ]]; then
|
||||||
|
echo "$actual"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
# Relative bind mount: docker-compose resolves it against the compose
|
||||||
|
# file's directory, but `docker run -v` resolves it against the current
|
||||||
|
# working directory. Normalize to an absolute path so both interpretations
|
||||||
|
# agree (and the printed revert command works from any CWD).
|
||||||
|
if [[ "$short" == ./* || "$short" == ../* ]]; then
|
||||||
|
local compose_dir
|
||||||
|
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
|
||||||
|
(
|
||||||
|
cd "$compose_dir"
|
||||||
|
cd "$(dirname "$short")"
|
||||||
|
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
|
||||||
|
echo "$short"
|
||||||
|
}
|
||||||
|
|
||||||
|
backup_sqlite() {
|
||||||
|
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
|
||||||
|
mkdir -p "$BACKUP_DIR"
|
||||||
|
local data_volume_actual
|
||||||
|
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||||
|
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
|
||||||
|
docker run --rm \
|
||||||
|
-v "${data_volume_actual}:/var/lib/netbird:ro" \
|
||||||
|
-v "${BACKUP_DIR}:/backup" \
|
||||||
|
busybox \
|
||||||
|
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
|
||||||
|
local copied
|
||||||
|
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
|
||||||
|
if [[ -z "$copied" ]]; then
|
||||||
|
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo " done"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_migrate_store() {
|
||||||
|
echo "Running migrate-store (SQLite → Postgres) ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
|
||||||
|
echo " done"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
init_migration() {
|
||||||
|
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||||
|
check_yq
|
||||||
|
check_openssl
|
||||||
|
|
||||||
|
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||||
|
|
||||||
|
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||||
|
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
|
||||||
|
echo "Migration artifacts already exist in $(pwd):"
|
||||||
|
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
|
||||||
|
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo ""
|
||||||
|
echo "Either you've already migrated, or a previous run was interrupted."
|
||||||
|
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
COMBINED_SERVICE=$(detect_combined_service)
|
||||||
|
DASHBOARD_SERVICE=$(detect_dashboard_service)
|
||||||
|
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
|
||||||
|
DATA_VOLUME=$(detect_data_volume)
|
||||||
|
COMPOSE_NETWORK=$(detect_compose_network)
|
||||||
|
|
||||||
|
if [[ -z "$COMBINED_SERVICE" ]]; then
|
||||||
|
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
|
||||||
|
echo "This script targets the community combined-server deployment." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$DASHBOARD_SERVICE" ]]; then
|
||||||
|
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$CONFIG_YAML_HOST" ]]; then
|
||||||
|
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
|
||||||
|
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$DATA_VOLUME" ]]; then
|
||||||
|
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Detected existing deployment:"
|
||||||
|
echo " Combined service: $COMBINED_SERVICE"
|
||||||
|
echo " Dashboard: $DASHBOARD_SERVICE"
|
||||||
|
echo " config.yaml: $CONFIG_YAML_HOST"
|
||||||
|
echo " Data volume: $DATA_VOLUME"
|
||||||
|
echo " Network: $COMPOSE_NETWORK"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
local proceed
|
||||||
|
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||||
|
if [[ "$proceed" != "yes" ]]; then
|
||||||
|
echo "Aborted."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 1 — always (this is the point of the script)
|
||||||
|
MIGRATE_IMAGES="yes"
|
||||||
|
echo ""
|
||||||
|
echo "Step 1: Image swap (community → Enterprise). License key required."
|
||||||
|
NB_LICENSE_KEY=$(read_secret " License key")
|
||||||
|
GHCR_USERNAME="netbirdExtAccess1"
|
||||||
|
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
|
||||||
|
|
||||||
|
# Step 2 — optional
|
||||||
|
echo ""
|
||||||
|
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
|
||||||
|
echo " will be backed up automatically. To fully revert later, restore"
|
||||||
|
echo " that backup and delete docker-compose.override.yml +"
|
||||||
|
echo " config.yaml.enterprise."
|
||||||
|
local confirm
|
||||||
|
confirm=$(read_yes_no " Continue?" "y")
|
||||||
|
if [[ "$confirm" != "yes" ]]; then
|
||||||
|
MIGRATE_POSTGRES="no"
|
||||||
|
echo " Skipping Postgres migration."
|
||||||
|
else
|
||||||
|
POSTGRES_PASSWORD=$(rand_password)
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
|
||||||
|
echo ""
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
# Auth secret MUST match server.authSecret from config.yaml
|
||||||
|
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
|
||||||
|
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
|
||||||
|
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
NETBIRD_DOMAIN=$(detect_exposed_address)
|
||||||
|
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
|
||||||
|
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
|
||||||
|
fi
|
||||||
|
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
|
||||||
|
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
|
||||||
|
|
||||||
|
# We need the encryption key from the existing config.yaml for the enricher
|
||||||
|
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
|
||||||
|
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
|
||||||
|
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
ENABLE_FLOW="no"
|
||||||
|
echo "Step 3 (traffic flow) skipped — requires Postgres."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
apply_changes() {
|
||||||
|
echo ""
|
||||||
|
echo "Writing $OVERRIDE_FILE ..."
|
||||||
|
install -m 644 /dev/null "$OVERRIDE_FILE"
|
||||||
|
render_override > "$OVERRIDE_FILE"
|
||||||
|
|
||||||
|
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
|
||||||
|
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
render_enterprise_config
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Persist secrets that the override file references via env interpolation.
|
||||||
|
# We write them to a .env file in the current directory; docker compose
|
||||||
|
# picks it up automatically.
|
||||||
|
echo "Writing .env additions (mode 600) ..."
|
||||||
|
local ENV_FILE=".env"
|
||||||
|
touch "$ENV_FILE"
|
||||||
|
chmod 600 "$ENV_FILE"
|
||||||
|
{
|
||||||
|
echo ""
|
||||||
|
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||||
|
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||||
|
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||||
|
fi
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
|
||||||
|
fi
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
|
||||||
|
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
|
||||||
|
fi
|
||||||
|
} >> "$ENV_FILE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Logging in to ghcr.io ..."
|
||||||
|
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||||
|
unset GHCR_TOKEN
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Pulling enterprise images ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND pull
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Stopping existing services (volumes preserved) ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND down
|
||||||
|
|
||||||
|
backup_sqlite
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting Postgres ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||||
|
|
||||||
|
# Wait for healthy
|
||||||
|
local counter=0
|
||||||
|
echo -n "Waiting for Postgres to become ready"
|
||||||
|
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
if [[ $counter -ge 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Postgres did not become ready in 120s. Recent logs:"
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
|
||||||
|
run_migrate_store
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Bringing up all services ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Migration complete."
|
||||||
|
}
|
||||||
|
|
||||||
|
print_summary() {
|
||||||
|
echo ""
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " Summary"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " Images: swapped to enterprise"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
|
||||||
|
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
|
||||||
|
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
|
||||||
|
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
|
||||||
|
echo ""
|
||||||
|
echo " Generated files (next to your docker-compose.yml):"
|
||||||
|
echo " $OVERRIDE_FILE"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo " .env (license key + secrets, mode 600)"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
|
||||||
|
echo ""
|
||||||
|
echo " Tail logs:"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
|
||||||
|
echo ""
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " To revert"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND down"
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
# Resolve project-prefixed volume names now (before override is removed).
|
||||||
|
local pg_volume data_volume_actual
|
||||||
|
pg_volume=$(resolve_data_volume "netbird_postgres")
|
||||||
|
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||||
|
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
|
||||||
|
echo " docker volume rm $pg_volume"
|
||||||
|
echo " # Restore SQLite from the backup created during this run:"
|
||||||
|
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
|
||||||
|
fi
|
||||||
|
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND up -d"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
init_migration
|
||||||
|
apply_changes
|
||||||
|
print_summary
|
||||||
89
management/cmd/admin.go
Normal file
89
management/cmd/admin.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter/hook"
|
||||||
|
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||||
|
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var adminDatadir string
|
||||||
|
|
||||||
|
// newAdminCommands creates the admin command tree with management-specific resource openers.
|
||||||
|
func newAdminCommands() *cobra.Command {
|
||||||
|
cmd := admincmd.NewCommands(withAdminResources)
|
||||||
|
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
|
||||||
|
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// withAdminResources initializes logging, loads config, opens the management store
|
||||||
|
// and embedded IdP storage, and calls fn.
|
||||||
|
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||||
|
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
|
||||||
|
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := idpStorage.Close(); err != nil {
|
||||||
|
log.Debugf("close embedded IdP storage: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// withAdminTokenStore opens only the management store for admin token commands.
|
||||||
|
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||||
|
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
|
||||||
|
return fn(ctx, managementStore)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
|
||||||
|
if err := util.InitLog("error", "console"); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
datadir := config.Datadir
|
||||||
|
if adminDatadir != "" {
|
||||||
|
oldDatadir := datadir
|
||||||
|
datadir = adminDatadir
|
||||||
|
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
|
||||||
|
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
|
||||||
|
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
|
||||||
|
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create store: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := managementStore.Close(ctx); err != nil {
|
||||||
|
log.Debugf("close store: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return fn(ctx, managementStore, config)
|
||||||
|
}
|
||||||
441
management/cmd/admin/admin.go
Normal file
441
management/cmd/admin/admin.go
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
|
||||||
|
// Both the management and combined binaries use these commands, each providing
|
||||||
|
// their own opener to handle config loading and storage initialization.
|
||||||
|
package admincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
localConnectorID = "local"
|
||||||
|
dashboardClientID = "netbird-dashboard"
|
||||||
|
cliClientID = "netbird-cli"
|
||||||
|
defaultTOTPAuthenticatorID = "default-totp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Resources contains the storages required by the admin commands.
|
||||||
|
type Resources struct {
|
||||||
|
Store store.Store
|
||||||
|
IDPStorage storage.Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Opener initializes command resources from the command context and calls fn.
|
||||||
|
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
|
||||||
|
|
||||||
|
type userSelector struct {
|
||||||
|
email string
|
||||||
|
userID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s userSelector) normalized() userSelector {
|
||||||
|
return userSelector{
|
||||||
|
email: strings.TrimSpace(s.email),
|
||||||
|
userID: strings.TrimSpace(s.userID),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s userSelector) validate() error {
|
||||||
|
s = s.normalized()
|
||||||
|
if (s.email == "") == (s.userID == "") {
|
||||||
|
return fmt.Errorf("provide exactly one of --email or --user-id")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommands creates the admin command tree with the given resource opener.
|
||||||
|
func NewCommands(opener Opener) *cobra.Command {
|
||||||
|
adminCmd := &cobra.Command{
|
||||||
|
Use: "admin",
|
||||||
|
Short: "Self-hosted administrator helpers",
|
||||||
|
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
|
||||||
|
}
|
||||||
|
|
||||||
|
userCmd := &cobra.Command{
|
||||||
|
Use: "user",
|
||||||
|
Short: "Manage local embedded IdP users",
|
||||||
|
}
|
||||||
|
|
||||||
|
var passwordSelector userSelector
|
||||||
|
var password string
|
||||||
|
var passwordFile string
|
||||||
|
passwordCmd := &cobra.Command{
|
||||||
|
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
|
||||||
|
Aliases: []string{"set-password"},
|
||||||
|
Short: "Change a local user's password",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||||
|
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addUserSelectorFlags(passwordCmd, &passwordSelector)
|
||||||
|
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
|
||||||
|
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
|
||||||
|
|
||||||
|
var resetSelector userSelector
|
||||||
|
resetMFACmd := &cobra.Command{
|
||||||
|
Use: "reset-mfa (--email email | --user-id id)",
|
||||||
|
Short: "Reset a local user's MFA enrollment",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||||
|
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addUserSelectorFlags(resetMFACmd, &resetSelector)
|
||||||
|
|
||||||
|
userCmd.AddCommand(passwordCmd, resetMFACmd)
|
||||||
|
|
||||||
|
mfaCmd := &cobra.Command{
|
||||||
|
Use: "mfa",
|
||||||
|
Short: "Manage local MFA for embedded IdP users",
|
||||||
|
}
|
||||||
|
|
||||||
|
enableCmd := &cobra.Command{
|
||||||
|
Use: "enable",
|
||||||
|
Short: "Enable MFA for local embedded IdP users",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||||
|
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
disableCmd := &cobra.Command{
|
||||||
|
Use: "disable",
|
||||||
|
Short: "Disable MFA for local embedded IdP users",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||||
|
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
statusCmd := &cobra.Command{
|
||||||
|
Use: "status",
|
||||||
|
Short: "Show local MFA status",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||||
|
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
|
||||||
|
adminCmd.AddCommand(userCmd, mfaCmd)
|
||||||
|
return adminCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
|
||||||
|
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
|
||||||
|
if cfg == nil || !cfg.Enabled {
|
||||||
|
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
yamlConfig, err := cfg.ToYAMLConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build embedded IdP config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
st, err := yamlConfig.Storage.OpenStorage(logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
|
||||||
|
}
|
||||||
|
return st, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
|
||||||
|
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
|
||||||
|
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
|
||||||
|
if password != "" && passwordFile != "" {
|
||||||
|
return "", fmt.Errorf("provide only one of --password or --password-file")
|
||||||
|
}
|
||||||
|
if passwordFile == "" {
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []byte
|
||||||
|
var err error
|
||||||
|
if passwordFile == "-" {
|
||||||
|
data, err = io.ReadAll(cmd.InOrStdin())
|
||||||
|
} else {
|
||||||
|
data, err = os.ReadFile(passwordFile)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read password: %w", err)
|
||||||
|
}
|
||||||
|
return strings.TrimRight(string(data), "\r\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
|
||||||
|
if idpStorage == nil {
|
||||||
|
return fmt.Errorf("embedded IdP storage is required")
|
||||||
|
}
|
||||||
|
selector = selector.normalized()
|
||||||
|
if err := selector.validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if password == "" {
|
||||||
|
return fmt.Errorf("password is required")
|
||||||
|
}
|
||||||
|
if err := server.ValidatePassword(password); err != nil {
|
||||||
|
return fmt.Errorf("invalid password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||||
|
old.Hash = hash
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("update password for %s: %w", user.Email, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
|
||||||
|
if idpStorage == nil {
|
||||||
|
return fmt.Errorf("embedded IdP storage is required")
|
||||||
|
}
|
||||||
|
selector = selector.normalized()
|
||||||
|
if err := selector.validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reset := false
|
||||||
|
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||||
|
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
|
||||||
|
old.MFASecrets = map[string]*storage.MFASecret{}
|
||||||
|
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
|
||||||
|
return old, nil
|
||||||
|
})
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if reset {
|
||||||
|
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
|
||||||
|
if resources.Store == nil {
|
||||||
|
return fmt.Errorf("management store is required")
|
||||||
|
}
|
||||||
|
if resources.IDPStorage == nil {
|
||||||
|
return fmt.Errorf("embedded IdP storage is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts := resources.Store.GetAllAccounts(ctx)
|
||||||
|
if len(accounts) != 1 {
|
||||||
|
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &types.Settings{}
|
||||||
|
if accounts[0].Settings != nil {
|
||||||
|
settings = accounts[0].Settings.Copy()
|
||||||
|
}
|
||||||
|
settings.LocalMfaEnabled = enabled
|
||||||
|
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
|
||||||
|
return fmt.Errorf("save local MFA account setting: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
state := "disabled"
|
||||||
|
if enabled {
|
||||||
|
state = "enabled"
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
|
||||||
|
if resources.Store == nil {
|
||||||
|
return fmt.Errorf("management store is required")
|
||||||
|
}
|
||||||
|
if resources.IDPStorage == nil {
|
||||||
|
return fmt.Errorf("embedded IdP storage is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts := resources.Store.GetAllAccounts(ctx)
|
||||||
|
accountStatus := "unknown"
|
||||||
|
if len(accounts) == 1 && accounts[0].Settings != nil {
|
||||||
|
accountStatus = "disabled"
|
||||||
|
if accounts[0].Settings.LocalMfaEnabled {
|
||||||
|
accountStatus = "enabled"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
|
||||||
|
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
|
||||||
|
selector = selector.normalized()
|
||||||
|
if err := selector.validate(); err != nil {
|
||||||
|
return storage.Password{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if selector.email != "" {
|
||||||
|
user, err := idpStorage.GetPassword(ctx, selector.email)
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rawUserID := selector.userID
|
||||||
|
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
|
||||||
|
rawUserID = decodedUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
users, err := idpStorage.ListPasswords(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return storage.Password{}, fmt.Errorf("list local users: %w", err)
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
if user.UserID == rawUserID || user.UserID == selector.userID {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
|
||||||
|
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
|
||||||
|
if err == nil || errors.Is(err, storage.ErrNotFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
|
||||||
|
var mfaChain []string
|
||||||
|
if enabled {
|
||||||
|
mfaChain = []string{defaultTOTPAuthenticatorID}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clientID := range []string{cliClientID, dashboardClientID} {
|
||||||
|
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||||
|
old.MFAChain = mfaChain
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
|
||||||
|
clientIDs := []string{cliClientID, dashboardClientID}
|
||||||
|
enabledCount := 0
|
||||||
|
for _, clientID := range clientIDs {
|
||||||
|
client, err := idpStorage.GetClient(ctx, clientID)
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
|
||||||
|
}
|
||||||
|
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
|
||||||
|
enabledCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch enabledCount {
|
||||||
|
case 0:
|
||||||
|
return "disabled", nil
|
||||||
|
case len(clientIDs):
|
||||||
|
return "enabled", nil
|
||||||
|
default:
|
||||||
|
return "partially enabled", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasAuthenticator(chain []string, authenticatorID string) bool {
|
||||||
|
for _, id := range chain {
|
||||||
|
if id == authenticatorID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
160
management/cmd/admin/admin_test.go
Normal file
160
management/cmd/admin/admin_test.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package admincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
"github.com/dexidp/dex/storage/memory"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestIDPStorage(t *testing.T) storage.Storage {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
|
||||||
|
Email: "user@example.com",
|
||||||
|
Username: "User",
|
||||||
|
UserID: "user-1",
|
||||||
|
Hash: hash,
|
||||||
|
}))
|
||||||
|
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
|
||||||
|
UserID: "user-1",
|
||||||
|
ConnectorID: localConnectorID,
|
||||||
|
MFASecrets: map[string]*storage.MFASecret{
|
||||||
|
defaultTOTPAuthenticatorID: {
|
||||||
|
AuthenticatorID: defaultTOTPAuthenticatorID,
|
||||||
|
Type: "TOTP",
|
||||||
|
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
|
||||||
|
Confirmed: true,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
|
||||||
|
"webauthn": {{CredentialID: []byte("credential")}},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
|
||||||
|
UserID: "user-1",
|
||||||
|
ConnectorID: localConnectorID,
|
||||||
|
Nonce: "nonce",
|
||||||
|
}))
|
||||||
|
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
|
||||||
|
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
|
||||||
|
|
||||||
|
return st
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunChangePassword(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
var out bytes.Buffer
|
||||||
|
|
||||||
|
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, out.String(), "Password updated")
|
||||||
|
|
||||||
|
user, err := st.GetPassword(ctx, "user@example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
|
||||||
|
|
||||||
|
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||||
|
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunChangePasswordValidatesPassword(t *testing.T) {
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid password")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunResetMFA(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
var out bytes.Buffer
|
||||||
|
|
||||||
|
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
|
||||||
|
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, out.String(), "MFA reset")
|
||||||
|
|
||||||
|
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, identity.MFASecrets)
|
||||||
|
require.Empty(t, identity.WebAuthnCredentials)
|
||||||
|
|
||||||
|
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||||
|
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||||
|
old.MFASecrets = nil
|
||||||
|
old.WebAuthnCredentials = nil
|
||||||
|
return old, nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
var out bytes.Buffer
|
||||||
|
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, out.String(), "No MFA enrollment found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetIDPClientsMFA(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
|
||||||
|
require.NoError(t, setIDPClientsMFA(ctx, st, true))
|
||||||
|
status, err := idpClientsMFAStatus(ctx, st)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "enabled", status)
|
||||||
|
|
||||||
|
require.NoError(t, setIDPClientsMFA(ctx, st, false))
|
||||||
|
status, err = idpClientsMFAStatus(ctx, st)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "disabled", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserSelectorValidate(t *testing.T) {
|
||||||
|
require.NoError(t, userSelector{email: " user@example.com "}.validate())
|
||||||
|
require.NoError(t, userSelector{userID: "user-1"}.validate())
|
||||||
|
require.Error(t, userSelector{}.validate())
|
||||||
|
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindLocalUserNotFound(t *testing.T) {
|
||||||
|
st := newTestIDPStorage(t)
|
||||||
|
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, strings.Contains(err.Error(), "not found"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePasswordInputFromStdin(t *testing.T) {
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetIn(strings.NewReader("NewPass1!\n"))
|
||||||
|
|
||||||
|
password, err := resolvePasswordInput(cmd, "", "-")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "NewPass1!", password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
|
||||||
|
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -83,7 +83,7 @@ func init() {
|
|||||||
|
|
||||||
rootCmd.AddCommand(migrationCmd)
|
rootCmd.AddCommand(migrationCmd)
|
||||||
|
|
||||||
tc := newTokenCommands()
|
ac := newAdminCommands()
|
||||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||||
rootCmd.AddCommand(tc)
|
rootCmd.AddCommand(ac)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/formatter/hook"
|
|
||||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
var tokenDatadir string
|
|
||||||
|
|
||||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
|
||||||
func newTokenCommands() *cobra.Command {
|
|
||||||
cmd := tokencmd.NewCommands(withTokenStore)
|
|
||||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
|
||||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
|
||||||
if err := util.InitLog("error", "console"); err != nil {
|
|
||||||
return fmt.Errorf("init log: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
|
||||||
|
|
||||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("load config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
datadir := config.Datadir
|
|
||||||
if tokenDatadir != "" {
|
|
||||||
datadir = tokenDatadir
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create store: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := s.Close(ctx); err != nil {
|
|
||||||
log.Debugf("close store: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return fn(ctx, s)
|
|
||||||
}
|
|
||||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetServiceManager(s.ServiceManager())
|
proxyService.SetServiceManager(s.ServiceManager())
|
||||||
proxyService.SetProxyController(s.ServiceProxyController())
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||||
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
|
|||||||
// Manager for users
|
// Manager for users
|
||||||
usersManager users.Manager
|
usersManager users.Manager
|
||||||
|
|
||||||
|
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||||
|
idpManager idp.Manager
|
||||||
|
|
||||||
// Store for one-time authentication tokens
|
// Store for one-time authentication tokens
|
||||||
tokenStore *OneTimeTokenStore
|
tokenStore *OneTimeTokenStore
|
||||||
|
|
||||||
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyServiceServer creates a new proxy service server.
|
// NewProxyServiceServer creates a new proxy service server.
|
||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
|||||||
pkceVerifierStore: pkceStore,
|
pkceVerifierStore: pkceStore,
|
||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
|
idpManager: idpManager,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
tokenChecker: tokenChecker,
|
tokenChecker: tokenChecker,
|
||||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||||
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||||
|
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||||
// Resolve the principal: when the peer is linked to a user, the human
|
|
||||||
// is the principal so multiple peers owned by the same user share a
|
|
||||||
// single identity. Unlinked peers (machine agents) are their own
|
|
||||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
|
||||||
// tag spend with — user.Email when linked, peer.Name when not.
|
|
||||||
principalID := peer.ID
|
|
||||||
displayIdentity := peer.Name
|
|
||||||
if peer.UserID != "" {
|
|
||||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
|
||||||
principalID = user.Id
|
|
||||||
if user.Email != "" {
|
|
||||||
displayIdentity = user.Email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||||
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||||
|
// user or peer ID, and peer name or user email.
|
||||||
|
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||||
|
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||||
|
// principal so multiple peers owned by the same user share a single
|
||||||
|
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||||
|
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||||
|
// user.Email when linked, peer.Name when not.
|
||||||
|
|
||||||
|
// If the peer isn't associated with a user, return the peer info directly.
|
||||||
|
if peer.UserID == "" {
|
||||||
|
return peer.ID, peer.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||||
|
// if an IdP is available, we gather details on the user from it.
|
||||||
|
principalID := peer.UserID
|
||||||
|
displayIdentity := peer.Name
|
||||||
|
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||||
|
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||||
|
principalID = user.Id
|
||||||
|
if user.Email != "" {
|
||||||
|
displayIdentity = user.Email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// IdP enrichment wins when available — the stored email column is a
|
||||||
|
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||||
|
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||||
|
if s.idpManager != nil {
|
||||||
|
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||||
|
displayIdentity = ud.Email
|
||||||
|
} else if uerr != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return principalID, displayIdentity
|
||||||
|
}
|
||||||
|
|
||||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||||
// groups. Private services authorise against AccessGroups (empty list fails
|
// groups. Private services authorise against AccessGroups (empty list fails
|
||||||
// closed — Validate() rejects that at save time but the RPC is the security
|
// closed — Validate() rejects that at save time but the RPC is the security
|
||||||
|
|||||||
@@ -3,14 +3,19 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockReverseProxyManager struct {
|
type mockReverseProxyManager struct {
|
||||||
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
|||||||
return user, nil, nil
|
return user, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||||
|
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||||
|
// panics if any unexpected method is invoked).
|
||||||
|
type mockTunnelPeersManager struct {
|
||||||
|
peers.Manager
|
||||||
|
peer *peer.Peer
|
||||||
|
peerErr error
|
||||||
|
groups []*types.Group
|
||||||
|
groupsErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||||
|
return m.peer, m.peerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
return m.peer, m.groups, m.groupsErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||||
|
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||||
|
// an IdP that knows nothing about the user.
|
||||||
|
type mockTunnelIdpManager struct {
|
||||||
|
idp.Manager
|
||||||
|
email string
|
||||||
|
hasData bool
|
||||||
|
err error
|
||||||
|
gotCalls int
|
||||||
|
gotMeta []idp.AppMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||||
|
m.gotCalls++
|
||||||
|
m.gotMeta = append(m.gotMeta, meta)
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
if !m.hasData {
|
||||||
|
// This might not be a thing any of the actual IDP implementations do,
|
||||||
|
// i.e. return a nil value with no error, but it seems valuable to test
|
||||||
|
// that behavior here.
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||||
|
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||||
|
// (IdP email -> stored User.Email -> peer.Name).
|
||||||
|
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||||
|
const (
|
||||||
|
domain = "app.example.com"
|
||||||
|
accountID = "account1"
|
||||||
|
peerID = "peer1"
|
||||||
|
peerName = "peer-display-name"
|
||||||
|
userID = "user1"
|
||||||
|
)
|
||||||
|
|
||||||
|
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||||
|
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerUserID string
|
||||||
|
storedUsers map[string]*types.User
|
||||||
|
storedErr error
|
||||||
|
noIdP bool
|
||||||
|
idpEmail string
|
||||||
|
idpHasData bool
|
||||||
|
idpErr error
|
||||||
|
expectEmail string
|
||||||
|
expectUserID string
|
||||||
|
expectIdPHit bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "idp email wins over stored email",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp returns empty email",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp has no data",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpHasData: false,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp errors",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpErr: errors.New("idp unreachable"),
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when no idp manager",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
noIdP: true,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "idp email when stored email is empty",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUserNoEmail,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: map[string]*types.User{},
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unlinked peer uses peer name and never consults idp",
|
||||||
|
peerUserID: "",
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: peerName,
|
||||||
|
expectUserID: peerID,
|
||||||
|
expectIdPHit: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUserNoEmail,
|
||||||
|
noIdP: true,
|
||||||
|
expectEmail: peerName,
|
||||||
|
expectUserID: userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
serviceManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||||
|
},
|
||||||
|
peersManager: &mockTunnelPeersManager{
|
||||||
|
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||||
|
},
|
||||||
|
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||||
|
}
|
||||||
|
|
||||||
|
var idpMock *mockTunnelIdpManager
|
||||||
|
if !tt.noIdP {
|
||||||
|
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||||
|
server.idpManager = idpMock
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||||
|
Domain: domain,
|
||||||
|
TunnelIp: "100.64.0.1",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.True(t, resp.GetValid(), "expected access granted")
|
||||||
|
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||||
|
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||||
|
|
||||||
|
if idpMock != nil {
|
||||||
|
if tt.expectIdPHit {
|
||||||
|
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||||
|
require.Len(t, idpMock.gotMeta, 1)
|
||||||
|
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|||||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
|
||||||
proxyService.SetServiceManager(serviceManager)
|
proxyService.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
createTestProxies(t, ctx, testStore)
|
||||||
|
|||||||
@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
usersManager,
|
usersManager,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
}
|
}
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
}
|
}
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
var peer *nbpeer.Peer
|
var peer *nbpeer.Peer
|
||||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||||
var err error
|
var err error
|
||||||
var postureChecks []*posture.Checks
|
|
||||||
var peerGroupIDs []string
|
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
return status.NewPeerLoginExpiredError()
|
return status.NewPeerLoginExpiredError()
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||||
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
|
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||||
if updated {
|
if updated {
|
||||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||||
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
|||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||||
changedPeerIDs := []string{peer.ID}
|
changedPeerIDs := []string{peer.ID}
|
||||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
|
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
}
|
}
|
||||||
|
|
||||||
var peer *nbpeer.Peer
|
var peer *nbpeer.Peer
|
||||||
var shouldStorePeer bool
|
var shouldStorePeer, shouldUpdatePeers bool
|
||||||
var peerGroupIDs []string
|
var peerGroupIDs []string
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||||
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
|
|
||||||
if changed {
|
if changed {
|
||||||
shouldStorePeer = true
|
shouldStorePeer = true
|
||||||
|
shouldUpdatePeers = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer.SSHKey != login.SSHKey {
|
if peer.SSHKey != login.SSHKey {
|
||||||
peer.SSHKey = login.SSHKey
|
peer.SSHKey = login.SSHKey
|
||||||
shouldStorePeer = true
|
shouldStorePeer = true
|
||||||
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
return nil, nil, nil, false, err
|
return nil, nil, nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||||
|
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||||
|
|
||||||
|
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, false, err
|
return nil, nil, nil, false, err
|
||||||
}
|
}
|
||||||
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
|||||||
return nil, nil, nil, false, err
|
return nil, nil, nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isStatusChanged || shouldStorePeer {
|
if shouldUpdatePeers {
|
||||||
changedPeerIDs := []string{peer.ID}
|
changedPeerIDs := []string{peer.ID}
|
||||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||||
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
|||||||
return network, nil, false, nil
|
return network, nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, false, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, false, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
|||||||
return network, postureChecks, enableSSH, nil
|
return network, postureChecks, enableSSH, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
|
||||||
if err != nil {
|
for _, peerID := range peerGroupIDs {
|
||||||
return false, err
|
groupIDsMap[peerID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
|
||||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
|
||||||
for _, g := range peerGroups {
|
|
||||||
peerGroupIDs[g.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks for the peer.
|
// getPeerPostureChecks returns the posture checks for the peer.
|
||||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(policies) == 0 {
|
if len(policies) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
|
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
||||||
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
|
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sourceGroup := range rule.Sources {
|
for _, sourceGroup := range rule.Sources {
|
||||||
group, ok := sourceGroups[sourceGroup]
|
if slices.Contains(peerGroupIDs, sourceGroup) {
|
||||||
if !ok {
|
return policy.SourcePostureChecks
|
||||||
return nil, fmt.Errorf("failed to check peer in policy source group")
|
|
||||||
}
|
|
||||||
|
|
||||||
if slices.Contains(group.Peers, peerID) {
|
|
||||||
return policy.SourcePostureChecks, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||||
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
|
return len(metaDiff(p, other)) == 0
|
||||||
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
|
|
||||||
})
|
|
||||||
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
|
|
||||||
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
|
|
||||||
})
|
|
||||||
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
|
|
||||||
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
|
|
||||||
})
|
|
||||||
if !equalNetworkAddresses {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(p.Files, func(i, j int) bool {
|
|
||||||
return p.Files[i].Path < p.Files[j].Path
|
|
||||||
})
|
|
||||||
sort.Slice(other.Files, func(i, j int) bool {
|
|
||||||
return other.Files[i].Path < other.Files[j].Path
|
|
||||||
})
|
|
||||||
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
|
|
||||||
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
|
|
||||||
})
|
|
||||||
if !equalFiles {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.Hostname == other.Hostname &&
|
|
||||||
p.GoOS == other.GoOS &&
|
|
||||||
p.Kernel == other.Kernel &&
|
|
||||||
p.KernelVersion == other.KernelVersion &&
|
|
||||||
p.Core == other.Core &&
|
|
||||||
p.Platform == other.Platform &&
|
|
||||||
p.OS == other.OS &&
|
|
||||||
p.OSVersion == other.OSVersion &&
|
|
||||||
p.WtVersion == other.WtVersion &&
|
|
||||||
p.UIVersion == other.UIVersion &&
|
|
||||||
p.SystemSerialNumber == other.SystemSerialNumber &&
|
|
||||||
p.SystemProductName == other.SystemProductName &&
|
|
||||||
p.SystemManufacturer == other.SystemManufacturer &&
|
|
||||||
p.Environment.Cloud == other.Environment.Cloud &&
|
|
||||||
p.Environment.Platform == other.Environment.Platform &&
|
|
||||||
p.Flags.isEqual(other.Flags) &&
|
|
||||||
capabilitiesEqual(p.Capabilities, other.Capabilities)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PeerSystemMeta) isEmpty() bool {
|
func (p PeerSystemMeta) isEmpty() bool {
|
||||||
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
|
|||||||
|
|
||||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||||
// returns true if meta was updated, false otherwise
|
// returns true if meta was updated, false otherwise
|
||||||
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
|
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||||
if meta.isEmpty() {
|
if meta.isEmpty() {
|
||||||
return updated, versionChanged
|
return updated, versionChanged
|
||||||
}
|
}
|
||||||
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
|
|||||||
meta.UIVersion = p.Meta.UIVersion
|
meta.UIVersion = p.Meta.UIVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Meta.isEqual(meta) {
|
oldVersion := p.Meta.WtVersion
|
||||||
return updated, versionChanged
|
|
||||||
|
diff := metaDiff(p.Meta, meta)
|
||||||
|
if len(diff) != 0 {
|
||||||
|
p.Meta = meta
|
||||||
|
updated = true
|
||||||
}
|
}
|
||||||
p.Meta = meta
|
|
||||||
updated = true
|
versionInfo := ""
|
||||||
|
if versionChanged {
|
||||||
|
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(diff) > 0 || versionChanged {
|
||||||
|
log.WithContext(ctx).
|
||||||
|
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
return updated, versionChanged
|
return updated, versionChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// metaDiff returns a human-readable list of the fields that differ between the
|
||||||
|
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||||
|
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||||
|
// diff, so the log line can never disagree with the change decision. Slices are
|
||||||
|
// cloned before sorting, so callers' meta is not mutated.
|
||||||
|
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||||
|
var diff []string
|
||||||
|
add := func(field string, oldVal, newVal any) {
|
||||||
|
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldMeta.Hostname != newMeta.Hostname {
|
||||||
|
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||||
|
}
|
||||||
|
if oldMeta.GoOS != newMeta.GoOS {
|
||||||
|
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||||
|
}
|
||||||
|
if oldMeta.Kernel != newMeta.Kernel {
|
||||||
|
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||||
|
}
|
||||||
|
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||||
|
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||||
|
}
|
||||||
|
if oldMeta.Core != newMeta.Core {
|
||||||
|
add("core", oldMeta.Core, newMeta.Core)
|
||||||
|
}
|
||||||
|
if oldMeta.Platform != newMeta.Platform {
|
||||||
|
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||||
|
}
|
||||||
|
if oldMeta.OS != newMeta.OS {
|
||||||
|
add("os", oldMeta.OS, newMeta.OS)
|
||||||
|
}
|
||||||
|
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||||
|
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||||
|
}
|
||||||
|
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||||
|
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||||
|
}
|
||||||
|
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||||
|
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||||
|
}
|
||||||
|
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||||
|
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||||
|
}
|
||||||
|
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||||
|
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||||
|
}
|
||||||
|
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||||
|
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||||
|
}
|
||||||
|
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||||
|
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||||
|
}
|
||||||
|
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||||
|
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||||
|
}
|
||||||
|
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||||
|
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||||
|
}
|
||||||
|
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||||
|
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||||
|
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||||
|
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||||
|
}
|
||||||
|
|
||||||
|
return diff
|
||||||
|
}
|
||||||
|
|
||||||
|
// sameMultiset reports whether two slices contain the same elements with the
|
||||||
|
// same multiplicity, ignoring order. The element type is the comparison key, so
|
||||||
|
// every field participates in equality.
|
||||||
|
func sameMultiset[T comparable](a, b []T) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
counts := make(map[T]int, len(a))
|
||||||
|
for _, v := range a {
|
||||||
|
counts[v]++
|
||||||
|
}
|
||||||
|
for _, v := range b {
|
||||||
|
counts[v]--
|
||||||
|
if counts[v] == 0 {
|
||||||
|
delete(counts, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(counts) == 0
|
||||||
|
}
|
||||||
|
|
||||||
// GetLastLogin returns the last login time of the peer.
|
// GetLastLogin returns the last login time of the peer.
|
||||||
func (p *Peer) GetLastLogin() time.Time {
|
func (p *Peer) GetLastLogin() time.Time {
|
||||||
if p.LastLogin != nil {
|
if p.LastLogin != nil {
|
||||||
|
|||||||
113
management/server/peer/peer_metadiff_test.go
Normal file
113
management/server/peer/peer_metadiff_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
|
||||||
|
// map 1:1 to a single diff entry. Today the only such field is Environment, which
|
||||||
|
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
|
||||||
|
// entry beyond its single struct field. If you teach metaDiff to explode another
|
||||||
|
// field into N entries, bump this by N-1; if you collapse a field, lower it.
|
||||||
|
const metaDiffExtraEntries = 1
|
||||||
|
|
||||||
|
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
|
||||||
|
// values and diffs it against the zero value, then asserts metaDiff emits exactly
|
||||||
|
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
|
||||||
|
//
|
||||||
|
// The expected count is derived from the struct via reflection, so adding a field
|
||||||
|
// to PeerSystemMeta raises the expectation automatically — but the actual diff
|
||||||
|
// only grows if metaDiff was taught to compare the new field. A mismatch means
|
||||||
|
// someone changed the struct without updating metaDiff (or this test's
|
||||||
|
// extra-entry accounting), which is exactly what we want to catch.
|
||||||
|
func TestMetaDiff_CoversAllFields(t *testing.T) {
|
||||||
|
var full PeerSystemMeta
|
||||||
|
exported := populateAll(t, reflect.ValueOf(&full).Elem())
|
||||||
|
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
|
||||||
|
|
||||||
|
diff := metaDiff(PeerSystemMeta{}, full)
|
||||||
|
|
||||||
|
require.Len(t, diff, exported+metaDiffExtraEntries,
|
||||||
|
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
|
||||||
|
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
|
||||||
|
"diff was: %v", diff)
|
||||||
|
|
||||||
|
require.False(t, full.isEqual(PeerSystemMeta{}),
|
||||||
|
"isEqual must report a fully-populated meta as different from the zero value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
|
||||||
|
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
|
||||||
|
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
|
||||||
|
// compare would not change the diff count. This flips each Flags field on its own
|
||||||
|
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
|
||||||
|
// fails here.
|
||||||
|
func TestFlags_isEqualChecksEveryField(t *testing.T) {
|
||||||
|
typ := reflect.TypeOf(Flags{})
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
f := typ.Field(i)
|
||||||
|
require.Equal(t, reflect.Bool, f.Type.Kind(),
|
||||||
|
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
|
||||||
|
|
||||||
|
var a, b Flags
|
||||||
|
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
|
||||||
|
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateAll sets every exported field of the struct to a deterministic non-zero
|
||||||
|
// value, recursing into nested structs and the element type of struct slices so
|
||||||
|
// that each leaf differs from zero. It returns the number of exported fields on
|
||||||
|
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
|
||||||
|
// settable exported fields and is comparable with ==).
|
||||||
|
func populateAll(t *testing.T, v reflect.Value) int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
typ := v.Type()
|
||||||
|
exported := 0
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
f := typ.Field(i)
|
||||||
|
if f.PkgPath != "" { // unexported
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
exported++
|
||||||
|
setNonZero(t, v.Field(i))
|
||||||
|
}
|
||||||
|
return exported
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
|
||||||
|
// recursing into nested structs and populating one element of slice fields.
|
||||||
|
func setNonZero(t *testing.T, field reflect.Value) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
|
||||||
|
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
field.SetString("non-zero")
|
||||||
|
case reflect.Bool:
|
||||||
|
field.SetBool(true)
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
field.SetInt(7)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
field.SetUint(7)
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
field.SetFloat(7)
|
||||||
|
case reflect.Struct:
|
||||||
|
populateAll(t, field)
|
||||||
|
case reflect.Slice:
|
||||||
|
s := reflect.MakeSlice(field.Type(), 1, 1)
|
||||||
|
setNonZero(t, s.Index(0))
|
||||||
|
field.Set(s)
|
||||||
|
default:
|
||||||
|
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1847,12 +1847,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
|
|||||||
|
|
||||||
const minPasswordLength = 8
|
const minPasswordLength = 8
|
||||||
|
|
||||||
// validatePassword checks password strength requirements:
|
// validatePassword checks password strength requirements.
|
||||||
|
func validatePassword(password string) error {
|
||||||
|
return ValidatePassword(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePassword checks password strength requirements:
|
||||||
// - Minimum 8 characters
|
// - Minimum 8 characters
|
||||||
// - At least 1 digit
|
// - At least 1 digit
|
||||||
// - At least 1 uppercase letter
|
// - At least 1 uppercase letter
|
||||||
// - At least 1 special character
|
// - At least 1 special character
|
||||||
func validatePassword(password string) error {
|
func ValidatePassword(password string) error {
|
||||||
if len(password) < minPasswordLength {
|
if len(password) < minPasswordLength {
|
||||||
return errors.New("password must be at least 8 characters long")
|
return errors.New("password must be at least 8 characters long")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
|||||||
oidcConfig,
|
oidcConfig,
|
||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
|
nil,
|
||||||
realProxyManager,
|
realProxyManager,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
|||||||
oidcConfig,
|
oidcConfig,
|
||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
|
nil,
|
||||||
proxyManager,
|
proxyManager,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
|
|||||||
|
|
||||||
fetch_all_tags() {
|
fetch_all_tags() {
|
||||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||||
|
grep -iv 'rc' | \
|
||||||
sed 's/.*\/v//' | \
|
sed 's/.*\/v//' | \
|
||||||
sort -u -V
|
sort -u -V
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ fetch_current_ports_version() {
|
|||||||
fetch_all_tags() {
|
fetch_all_tags() {
|
||||||
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
# Fetch tags from GitHub tags page (no rate limiting, no auth needed)
|
||||||
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
|
||||||
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \
|
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
|
||||||
|
grep -iv 'rc' | \
|
||||||
sed 's/.*\/v//' | \
|
sed 's/.*\/v//' | \
|
||||||
sort -u -V
|
sort -u -V
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type Client interface {
|
|||||||
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||||
Ready() bool
|
Ready() bool
|
||||||
IsHealthy() bool
|
IsHealthy() bool
|
||||||
WaitStreamConnected()
|
WaitStreamConnected(context.Context)
|
||||||
SendToStream(msg *proto.EncryptedMessage) error
|
SendToStream(msg *proto.EncryptedMessage) error
|
||||||
Send(msg *proto.Message) error
|
Send(msg *proto.Message) error
|
||||||
SetOnReconnectedListener(func())
|
SetOnReconnectedListener(func())
|
||||||
|
|||||||
@@ -65,7 +65,10 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
clientA.WaitStreamConnected()
|
ctxA, cancelA := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancelA()
|
||||||
|
clientA.WaitStreamConnected(ctxA)
|
||||||
|
Expect(clientA.StreamConnected()).To(BeTrue())
|
||||||
|
|
||||||
// connect PeerB to Signal
|
// connect PeerB to Signal
|
||||||
keyB, _ := wgtypes.GenerateKey()
|
keyB, _ := wgtypes.GenerateKey()
|
||||||
@@ -91,7 +94,10 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientB.WaitStreamConnected()
|
ctxB, cancelB := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancelB()
|
||||||
|
clientB.WaitStreamConnected(ctxB)
|
||||||
|
Expect(clientB.StreamConnected()).To(BeTrue())
|
||||||
|
|
||||||
// PeerA initiates ping-pong
|
// PeerA initiates ping-pong
|
||||||
err := clientA.Send(&sigProto.Message{
|
err := clientA.Send(&sigProto.Message{
|
||||||
@@ -129,8 +135,10 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
client.WaitStreamConnected()
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
Expect(client).NotTo(BeNil())
|
defer cancel()
|
||||||
|
client.WaitStreamConnected(ctx)
|
||||||
|
Expect(client.StreamConnected()).To(BeTrue())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -246,15 +246,6 @@ func (c *GrpcClient) notifyStreamConnected() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
if c.connectedCh == nil {
|
|
||||||
c.connectedCh = make(chan struct{})
|
|
||||||
}
|
|
||||||
return c.connectedCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||||
c.stream = nil
|
c.stream = nil
|
||||||
|
|
||||||
@@ -310,14 +301,24 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||||
func (c *GrpcClient) WaitStreamConnected() {
|
func (c *GrpcClient) WaitStreamConnected(ctx context.Context) {
|
||||||
|
// Check the status and obtain the wait channel atomically: otherwise
|
||||||
|
// notifyStreamConnected could flip the status and close/clear the channel
|
||||||
|
// between the check and the channel creation, leaving us waiting forever on
|
||||||
|
// a stale channel.
|
||||||
|
c.mux.Lock()
|
||||||
if c.status == StreamConnected {
|
if c.status == StreamConnected {
|
||||||
|
c.mux.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if c.connectedCh == nil {
|
||||||
|
c.connectedCh = make(chan struct{})
|
||||||
|
}
|
||||||
|
ch := c.connectedCh
|
||||||
|
c.mux.Unlock()
|
||||||
|
|
||||||
ch := c.getStreamStatusChan()
|
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
case <-ch:
|
case <-ch:
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (sm *MockClient) Ready() bool {
|
|||||||
return sm.ReadyFunc()
|
return sm.ReadyFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *MockClient) WaitStreamConnected() {
|
func (sm *MockClient) WaitStreamConnected(context.Context) {
|
||||||
if sm.WaitStreamConnectedFunc == nil {
|
if sm.WaitStreamConnectedFunc == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
|
|||||||
|
|
||||||
streamReady := make(chan struct{})
|
streamReady := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
client.WaitStreamConnected()
|
client.WaitStreamConnected(ctx)
|
||||||
close(streamReady)
|
close(streamReady)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ type Peer struct {
|
|||||||
|
|
||||||
// a gRpc connection stream to the Peer
|
// a gRpc connection stream to the Peer
|
||||||
Stream proto.SignalExchange_ConnectStreamServer
|
Stream proto.SignalExchange_ConnectStreamServer
|
||||||
|
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
|
||||||
|
// the same ServerStream, and a peer can be the target of many senders at
|
||||||
|
// once.
|
||||||
|
sendMu sync.Mutex
|
||||||
|
|
||||||
// registration time
|
// registration time
|
||||||
RegisteredAt time.Time
|
RegisteredAt time.Time
|
||||||
@@ -33,6 +37,13 @@ type Peer struct {
|
|||||||
Cancel context.CancelFunc
|
Cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send writes a message to the peer's stream, serializing concurrent senders.
|
||||||
|
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
|
||||||
|
p.sendMu.Lock()
|
||||||
|
defer p.sendMu.Unlock()
|
||||||
|
return p.Stream.Send(msg)
|
||||||
|
}
|
||||||
|
|
||||||
// NewPeer creates a new instance of a connected Peer
|
// NewPeer creates a new instance of a connected Peer
|
||||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
|
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
|
||||||
return &Peer{
|
return &Peer{
|
||||||
|
|||||||
67
signal/server/concurrent_send_test.go
Normal file
67
signal/server/concurrent_send_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
"github.com/netbirdio/netbird/signal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// concurrencyCheckStream records the maximum number of Send calls in flight at
|
||||||
|
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
|
||||||
|
// server must never have more than one in flight per peer.
|
||||||
|
type concurrencyCheckStream struct {
|
||||||
|
proto.SignalExchange_ConnectStreamServer
|
||||||
|
ctx context.Context
|
||||||
|
inflight atomic.Int32
|
||||||
|
maxSeen atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
|
||||||
|
n := s.inflight.Add(1)
|
||||||
|
for {
|
||||||
|
old := s.maxSeen.Load()
|
||||||
|
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Widen the window so overlapping callers are reliably observed.
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
s.inflight.Add(-1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
|
||||||
|
|
||||||
|
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
|
||||||
|
// same peer never call Stream.Send concurrently, which would violate the gRPC
|
||||||
|
// ServerStream contract.
|
||||||
|
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
|
||||||
|
s, err := NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const peerID = "peerX"
|
||||||
|
stream := &concurrencyCheckStream{ctx: context.Background()}
|
||||||
|
_, cancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
|
||||||
|
}
|
||||||
@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
|
|||||||
sendResultChan := make(chan error, 1)
|
sendResultChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case sendResultChan <- dstPeer.Stream.Send(msg):
|
case sendResultChan <- dstPeer.Send(msg):
|
||||||
return
|
return
|
||||||
case <-dstPeer.Stream.Context().Done():
|
case <-dstPeer.Stream.Context().Done():
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user