mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-30 11:49:56 +00:00
Compare commits
54 Commits
fix/ipv6-a
...
netmap_pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4988b6726e | ||
|
|
2552830184 | ||
|
|
3b8fc688f4 | ||
|
|
d82d62e818 | ||
|
|
0bf964dad7 | ||
|
|
297dcb3e24 | ||
|
|
bc22926fe0 | ||
|
|
d3f2ef9adb | ||
|
|
5bec1e8f03 | ||
|
|
74bb5c613e | ||
|
|
29dde908ae | ||
|
|
2d7b309004 | ||
|
|
5968cff242 | ||
|
|
cf43841b86 | ||
|
|
739e36a313 | ||
|
|
2bb5421631 | ||
|
|
998ade6e6d | ||
|
|
62f5467cd8 | ||
|
|
1b29995ece | ||
|
|
fd96b8c12f | ||
|
|
6dd6c3f398 | ||
|
|
d1422dcf09 | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a | ||
|
|
6c26178ad5 | ||
|
|
af3b7e4497 | ||
|
|
e84f6527f7 | ||
|
|
ac9529ea8c | ||
|
|
f736ef9647 | ||
|
|
cf58bf1ba9 | ||
|
|
522b8ed969 | ||
|
|
c9e99659ea | ||
|
|
58c79f5878 | ||
|
|
15a0504fb1 | ||
|
|
883a1a8961 | ||
|
|
54192a94b7 | ||
|
|
8511687270 | ||
|
|
35b465fa4a | ||
|
|
fb87f751a5 | ||
|
|
679c7182a4 | ||
|
|
8c031ea6f0 | ||
|
|
60a9544656 | ||
|
|
d3710d4bb2 |
@@ -20,7 +20,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -59,12 +59,12 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: true
|
cache: true
|
||||||
|
|||||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||||
|
|||||||
10
.github/workflows/golang-test-darwin.yml
vendored
10
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,18 +16,18 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -45,10 +45,10 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
|
|||||||
24
.github/workflows/golang-test-freebsd.yml
vendored
24
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ jobs:
|
|||||||
id: test
|
id: test
|
||||||
env:
|
env:
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
@@ -48,14 +48,14 @@ jobs:
|
|||||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -tags privileged -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -tags privileged -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -tags privileged -timeout 1m -failfast ./encryption/...
|
||||||
time go test -timeout 1m -failfast ./formatter/...
|
time go test -tags privileged -timeout 1m -failfast ./formatter/...
|
||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -tags privileged -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -tags privileged -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -tags privileged -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
82
.github/workflows/golang-test-linux.yml
vendored
82
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
management: ${{ steps.filter.outputs.management }}
|
management: ${{ steps.filter.outputs.management }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- 'management/**'
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -41,7 +41,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -119,12 +119,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -135,7 +135,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -158,11 +158,11 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -175,12 +175,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -192,7 +192,7 @@ jobs:
|
|||||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache-restore
|
id: cache-restore
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -229,7 +229,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -246,12 +246,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -266,7 +266,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -290,7 +290,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -306,12 +306,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -325,7 +325,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -347,7 +347,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -363,12 +363,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -383,7 +383,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -407,7 +407,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -424,12 +424,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -440,7 +440,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -484,7 +484,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
@@ -529,12 +529,12 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -545,7 +545,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -579,10 +579,11 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
@@ -623,12 +624,12 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -639,7 +640,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -673,12 +674,13 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/server/http/...
|
-timeout 20m ./management/server/http/...
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
name: "Management / Integration"
|
name: "Management / Integration"
|
||||||
@@ -692,12 +694,12 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -708,7 +710,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -734,7 +736,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
slug: netbirdio/netbird
|
slug: netbirdio/netbird
|
||||||
|
|||||||
8
.github/workflows/golang-test-windows.yml
vendored
8
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
|||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
@@ -35,7 +35,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
|
|||||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: codespell
|
- name: codespell
|
||||||
@@ -40,7 +40,7 @@ jobs:
|
|||||||
timeout-minutes: 15
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Check for duplicate constants
|
- name: Check for duplicate constants
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|||||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
|||||||
12
.github/workflows/mobile-build-validation.yml
vendored
12
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
@@ -28,13 +28,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
cmdline-tools-version: 8512546
|
cmdline-tools-version: 8512546
|
||||||
- name: Setup Java
|
- name: Setup Java
|
||||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||||
with:
|
with:
|
||||||
java-version: "11"
|
java-version: "11"
|
||||||
distribution: "adopt"
|
distribution: "adopt"
|
||||||
- name: NDK Cache
|
- name: NDK Cache
|
||||||
id: ndk-cache
|
id: ndk-cache
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: /usr/local/lib/android/sdk/ndk
|
path: /usr/local/lib/android/sdk/ndk
|
||||||
key: ndk-cache-23.1.7779620
|
key: ndk-cache-23.1.7779620
|
||||||
@@ -54,11 +54,11 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
|
|||||||
38
.github/workflows/release.yml
vendored
38
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ jobs:
|
|||||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
env:
|
env:
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
@@ -135,7 +135,7 @@ jobs:
|
|||||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -166,12 +166,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -186,9 +186,9 @@ jobs:
|
|||||||
- name: check git status
|
- name: check git status
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
@@ -221,7 +221,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --clean ${{ env.flags }}
|
args: release --clean ${{ env.flags }}
|
||||||
@@ -347,7 +347,7 @@ jobs:
|
|||||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -374,12 +374,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -420,7 +420,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||||
@@ -464,17 +464,17 @@ jobs:
|
|||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -488,7 +488,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||||
@@ -522,7 +522,7 @@ jobs:
|
|||||||
downloadPath: '${{ github.workspace }}\temp'
|
downloadPath: '${{ github.workspace }}\temp'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -534,13 +534,13 @@ jobs:
|
|||||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
|
|
||||||
- name: Download release artifacts
|
- name: Download release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: release
|
path: release
|
||||||
|
|
||||||
- name: Download UI release artifacts
|
- name: Download UI release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: release-ui
|
path: release-ui
|
||||||
|
|||||||
14
.github/workflows/test-infrastructure-files.yml
vendored
14
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,17 +68,17 @@ jobs:
|
|||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -207,7 +207,7 @@ jobs:
|
|||||||
- name: Build management docker image
|
- name: Build management docker image
|
||||||
working-directory: management
|
working-directory: management
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/management:latest .
|
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: Build signal binary
|
- name: Build signal binary
|
||||||
working-directory: signal
|
working-directory: signal
|
||||||
@@ -216,7 +216,7 @@ jobs:
|
|||||||
- name: Build signal docker image
|
- name: Build signal docker image
|
||||||
working-directory: signal
|
working-directory: signal
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/signal:latest .
|
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: Build relay binary
|
- name: Build relay binary
|
||||||
working-directory: relay
|
working-directory: relay
|
||||||
@@ -225,7 +225,7 @@ jobs:
|
|||||||
- name: Build relay docker image
|
- name: Build relay docker image
|
||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: |
|
run: |
|
||||||
docker build -t netbirdio/relay:latest .
|
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||||
|
|
||||||
- name: run docker compose up
|
- name: run docker compose up
|
||||||
working-directory: infrastructure_files/artifacts
|
working-directory: infrastructure_files/artifacts
|
||||||
@@ -256,7 +256,7 @@ jobs:
|
|||||||
run: sudo apt-get install -y jq
|
run: sudo apt-get install -y jq
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
|||||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
|||||||
GOARCH: wasm
|
GOARCH: wasm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -44,11 +44,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ dockers_v2:
|
|||||||
- netbirdio/netbird
|
- netbirdio/netbird
|
||||||
- ghcr.io/netbirdio/netbird
|
- ghcr.io/netbirdio/netbird
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: client/Dockerfile
|
dockerfile: client/Dockerfile
|
||||||
extra_files:
|
extra_files:
|
||||||
@@ -295,7 +295,7 @@ dockers_v2:
|
|||||||
- netbirdio/relay
|
- netbirdio/relay
|
||||||
- ghcr.io/netbirdio/relay
|
- ghcr.io/netbirdio/relay
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: relay/Dockerfile
|
dockerfile: relay/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -317,7 +317,7 @@ dockers_v2:
|
|||||||
- netbirdio/signal
|
- netbirdio/signal
|
||||||
- ghcr.io/netbirdio/signal
|
- ghcr.io/netbirdio/signal
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: signal/Dockerfile
|
dockerfile: signal/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -339,7 +339,7 @@ dockers_v2:
|
|||||||
- netbirdio/management
|
- netbirdio/management
|
||||||
- ghcr.io/netbirdio/management
|
- ghcr.io/netbirdio/management
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: management/Dockerfile
|
dockerfile: management/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -361,7 +361,7 @@ dockers_v2:
|
|||||||
- netbirdio/upload
|
- netbirdio/upload
|
||||||
- ghcr.io/netbirdio/upload
|
- ghcr.io/netbirdio/upload
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: upload-server/Dockerfile
|
dockerfile: upload-server/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -383,7 +383,7 @@ dockers_v2:
|
|||||||
- netbirdio/netbird-server
|
- netbirdio/netbird-server
|
||||||
- ghcr.io/netbirdio/netbird-server
|
- ghcr.io/netbirdio/netbird-server
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: combined/Dockerfile
|
dockerfile: combined/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -405,7 +405,7 @@ dockers_v2:
|
|||||||
- netbirdio/reverse-proxy
|
- netbirdio/reverse-proxy
|
||||||
- ghcr.io/netbirdio/reverse-proxy
|
- ghcr.io/netbirdio/reverse-proxy
|
||||||
tags:
|
tags:
|
||||||
- "v{{ .Version }}"
|
- "{{ .Version }}"
|
||||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||||
dockerfile: proxy/Dockerfile
|
dockerfile: proxy/Dockerfile
|
||||||
platforms:
|
platforms:
|
||||||
@@ -462,9 +462,13 @@ checksum:
|
|||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
- glob: ./release_files/install.sh
|
- glob: ./release_files/install.sh
|
||||||
- glob: ./infrastructure_files/getting-started.sh
|
- glob: ./infrastructure_files/getting-started.sh
|
||||||
|
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||||
|
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||||
|
|
||||||
release:
|
release:
|
||||||
extra_files:
|
extra_files:
|
||||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||||
- glob: ./release_files/install.sh
|
- glob: ./release_files/install.sh
|
||||||
- glob: ./infrastructure_files/getting-started.sh
|
- glob: ./infrastructure_files/getting-started.sh
|
||||||
|
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||||
|
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||||
|
|||||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: lint lint-all lint-install setup-hooks
|
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
# Install golangci-lint locally if needed
|
# Install golangci-lint locally if needed
|
||||||
@@ -25,3 +25,15 @@ setup-hooks:
|
|||||||
@git config core.hooksPath .githooks
|
@git config core.hooksPath .githooks
|
||||||
@chmod +x .githooks/pre-push
|
@chmod +x .githooks/pre-push
|
||||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
|
|
||||||
|
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||||
|
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||||
|
test-unit:
|
||||||
|
@go test -tags devcert -timeout 10m ./...
|
||||||
|
|
||||||
|
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||||
|
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||||
|
# Narrow the run with env vars, e.g.:
|
||||||
|
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||||
|
test-privileged:
|
||||||
|
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||||
|
|||||||
@@ -37,6 +37,11 @@
|
|||||||
</strong>
|
</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
> ### 🤖 NetBird Agent Network (Beta)
|
||||||
|
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||||
|
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||||
|
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||||
|
|
||||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|||||||
39
agent-network/README.md
Normal file
39
agent-network/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# NetBird Agent Network
|
||||||
|
|
||||||
|
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||||
|
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||||
|
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||||
|
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||||
|
policy, with no API keys to leak.
|
||||||
|
|
||||||
|
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||||
|
> infrastructure.
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
Agent Network is built on two existing NetBird capabilities:
|
||||||
|
|
||||||
|
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||||
|
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||||
|
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||||
|
key server-side, forwards to the API or gateway, and records usage.
|
||||||
|
|
||||||
|
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||||
|
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||||
|
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||||
|
|
||||||
|
## Where the code lives
|
||||||
|
|
||||||
|
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||||
|
components:
|
||||||
|
|
||||||
|
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||||
|
and runs the per-request middleware pipeline.
|
||||||
|
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||||
|
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||||
|
and usage/access logs.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
Full documentation, architecture, and quickstart:
|
||||||
|
**https://docs.netbird.io/agent-network**
|
||||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||||
ProfileName: activeProf.Name,
|
ProfileName: string(activeProf.ID),
|
||||||
Username: currUser.Username,
|
Username: currUser.Username,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
|
|||||||
Username: &username,
|
Username: &username,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
return "", fmt.Errorf("switch profile failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.ID(resp.Id), nil
|
return profilemanager.ID(resp.Id), nil
|
||||||
|
|||||||
@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
currUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get current user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
currUser, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get current user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||||
profileName := args[0]
|
profileName := args[0]
|
||||||
|
|
||||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
|
||||||
ProfileName: profileName,
|
|
||||||
Username: currUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add profile request: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||||
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||||
}
|
}
|
||||||
|
|
||||||
id := profilemanager.ID(resp.Id)
|
|
||||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
|
||||||
|
// and returns the new profile's ID. It is the single entry point for profile
|
||||||
|
// creation, shared by `netbird profile add` and the `netbird up --profile
|
||||||
|
// <name>` auto-create path.
|
||||||
|
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
|
||||||
|
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
|
||||||
|
ProfileName: profileName,
|
||||||
|
Username: username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("add profile failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return profilemanager.ID(resp.Id), nil
|
||||||
|
}
|
||||||
|
|||||||
196
client/cmd/service_privileged_test.go
Normal file
196
client/cmd/service_privileged_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/kardianos/service"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceStartTimeout = 10 * time.Second
|
||||||
|
serviceStopTimeout = 5 * time.Second
|
||||||
|
statusPollInterval = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||||
|
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer timeoutCancel()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(statusPollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||||
|
case <-ticker.C:
|
||||||
|
status, err := s.Status()
|
||||||
|
if err != nil {
|
||||||
|
// Continue polling on transient errors
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if status == expectedStatus {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceLifecycle tests the complete service lifecycle
|
||||||
|
func TestServiceLifecycle(t *testing.T) {
|
||||||
|
// TODO: Add support for Windows and macOS
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.Getenv("CONTAINER") == "true" {
|
||||||
|
t.Skip("Skipping service lifecycle test in container environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
originalServiceName := serviceName
|
||||||
|
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||||
|
defer func() {
|
||||||
|
serviceName = originalServiceName
|
||||||
|
}()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||||
|
|
||||||
|
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service config: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the subtests already cleaned up, there's nothing to do.
|
||||||
|
if _, err := s.Status(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Stop(); err != nil {
|
||||||
|
t.Errorf("cleanup: stop service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Uninstall(); err != nil {
|
||||||
|
t.Errorf("cleanup: uninstall service: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("Install", func(t *testing.T) {
|
||||||
|
installCmd.SetContext(ctx)
|
||||||
|
err := installCmd.RunE(installCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
status, err := s.Status()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, service.StatusUnknown, status)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Start", func(t *testing.T) {
|
||||||
|
startCmd.SetContext(ctx)
|
||||||
|
err := startCmd.RunE(startCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Restart", func(t *testing.T) {
|
||||||
|
restartCmd.SetContext(ctx)
|
||||||
|
err := restartCmd.RunE(restartCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Reconfigure", func(t *testing.T) {
|
||||||
|
originalLogLevel := logLevel
|
||||||
|
logLevel = "debug"
|
||||||
|
defer func() {
|
||||||
|
logLevel = originalLogLevel
|
||||||
|
}()
|
||||||
|
|
||||||
|
reconfigureCmd.SetContext(ctx)
|
||||||
|
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stop", func(t *testing.T) {
|
||||||
|
stopCmd.SetContext(ctx)
|
||||||
|
err := stopCmd.RunE(stopCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, stopped)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Uninstall", func(t *testing.T) {
|
||||||
|
uninstallCmd.SetContext(ctx)
|
||||||
|
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.Status()
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
serviceStartTimeout = 10 * time.Second
|
|
||||||
serviceStopTimeout = 5 * time.Second
|
|
||||||
statusPollInterval = 500 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
|
||||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer timeoutCancel()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(statusPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
|
||||||
case <-ticker.C:
|
|
||||||
status, err := s.Status()
|
|
||||||
if err != nil {
|
|
||||||
// Continue polling on transient errors
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if status == expectedStatus {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceLifecycle tests the complete service lifecycle
|
|
||||||
func TestServiceLifecycle(t *testing.T) {
|
|
||||||
// TODO: Add support for Windows and macOS
|
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
|
||||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.Getenv("CONTAINER") == "true" {
|
|
||||||
t.Skip("Skipping service lifecycle test in container environment")
|
|
||||||
}
|
|
||||||
|
|
||||||
originalServiceName := serviceName
|
|
||||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
|
||||||
defer func() {
|
|
||||||
serviceName = originalServiceName
|
|
||||||
}()
|
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
|
||||||
logLevel = "info"
|
|
||||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
|
||||||
|
|
||||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service config: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the subtests already cleaned up, there's nothing to do.
|
|
||||||
if _, err := s.Status(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.Stop(); err != nil {
|
|
||||||
t.Errorf("cleanup: stop service: %v", err)
|
|
||||||
}
|
|
||||||
if err := s.Uninstall(); err != nil {
|
|
||||||
t.Errorf("cleanup: uninstall service: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
t.Run("Install", func(t *testing.T) {
|
|
||||||
installCmd.SetContext(ctx)
|
|
||||||
err := installCmd.RunE(installCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
status, err := s.Status()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotEqual(t, service.StatusUnknown, status)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Start", func(t *testing.T) {
|
|
||||||
startCmd.SetContext(ctx)
|
|
||||||
err := startCmd.RunE(startCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Restart", func(t *testing.T) {
|
|
||||||
restartCmd.SetContext(ctx)
|
|
||||||
err := restartCmd.RunE(restartCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Reconfigure", func(t *testing.T) {
|
|
||||||
originalLogLevel := logLevel
|
|
||||||
logLevel = "debug"
|
|
||||||
defer func() {
|
|
||||||
logLevel = originalLogLevel
|
|
||||||
}()
|
|
||||||
|
|
||||||
reconfigureCmd.SetContext(ctx)
|
|
||||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Stop", func(t *testing.T) {
|
|
||||||
stopCmd.SetContext(ctx)
|
|
||||||
err := stopCmd.RunE(stopCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, stopped)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Uninstall", func(t *testing.T) {
|
|
||||||
uninstallCmd.SetContext(ctx)
|
|
||||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.Status()
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceEnvVars tests environment variable parsing
|
// TestServiceEnvVars tests environment variable parsing
|
||||||
func TestServiceEnvVars(t *testing.T) {
|
func TestServiceEnvVars(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pm := profilemanager.NewProfileManager()
|
// Resolve the active profile's display name via the daemon, which runs
|
||||||
var profName string
|
// as root and can read the per-user profile files. The local profile
|
||||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
// manager only knows the active profile ID, not its display name.
|
||||||
profName = activeProf.Name
|
profName := getActiveProfileName(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getActiveProfileName asks the daemon for the active profile's display
|
||||||
|
// name. The daemon runs as root and can read the per-user profile files to
|
||||||
|
// resolve the ID to its human-readable name. Returns an empty string on any
|
||||||
|
// error so status output degrades gracefully.
|
||||||
|
func getActiveProfileName(ctx context.Context) string {
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.GetProfileName()
|
||||||
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "idle", "connecting", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
|
|||||||
@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
var profileSwitched bool
|
var profileSwitched bool
|
||||||
// switch profile if provided
|
// switch profile if provided
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("switch profile: %v", err)
|
return fmt.Errorf("switch profile: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
|
||||||
return fmt.Errorf("switch profile: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
profileSwitched = true
|
profileSwitched = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// switchOrCreateProfile switches the active profile to the one identified by
|
||||||
|
// handle, creating it first when it does not exist yet. This restores the
|
||||||
|
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
|
||||||
|
// missing profile instead of failing.
|
||||||
|
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
|
||||||
|
resolvedID, err := switchProfile(ctx, handle, username)
|
||||||
|
if err != nil {
|
||||||
|
st, ok := gstatus.FromError(err)
|
||||||
|
if !ok || st.Code() != codes.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Don't fail immediately on a create error: a concurrent run may
|
||||||
|
// have created the profile between the NotFound above and this
|
||||||
|
// call, in which case the retried switch still succeeds. Only
|
||||||
|
// surface the create error if the switch also fails.
|
||||||
|
_, createErr := createProfile(ctx, handle, username)
|
||||||
|
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
|
||||||
|
if createErr != nil {
|
||||||
|
return fmt.Errorf("create profile: %w", createErr)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createProfile dials the daemon and creates a new profile with the given
|
||||||
|
// display name, returning its generated ID. Use addProfileOnDaemon directly
|
||||||
|
// when a daemon client is already available to reuse the connection.
|
||||||
|
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
//nolint
|
||||||
|
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||||
|
"If the daemon is not running please run: "+
|
||||||
|
"\nnetbird service install \nnetbird service start\n", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
|
||||||
|
}
|
||||||
|
|
||||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||||
// override the default profile filepath if provided
|
// override the default profile filepath if provided
|
||||||
if configPath != "" {
|
if configPath != "" {
|
||||||
|
|||||||
@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-startCtx.Done():
|
case <-startCtx.Done():
|
||||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||||
// signal stream while holding the engine mutex and only unblocks on
|
// run loop to tear the engine down, so this cancel() is no longer
|
||||||
// cancellation. Stopping first would deadlock on that mutex.
|
// required to break the deadlock and could be removed. It is kept as a
|
||||||
|
// defensive belt-and-suspenders: cancelling the parent context first
|
||||||
|
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||||
cancel()
|
cancel()
|
||||||
if stopErr := client.Stop(); stopErr != nil {
|
if stopErr := client.Stop(); stopErr != nil {
|
||||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux
|
//go:build !linux || !privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
|||||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
|
||||||
wgPort := 51850
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("192.168.0.56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
|
||||||
wgPort := 51851
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("fe80::56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||||
wgPort := 51852
|
wgPort := 51852
|
||||||
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51850
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51851
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||||
wgPort := 51856
|
wgPort := 51856
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
|
|||||||
|
|
||||||
type ConnectClient struct {
|
type ConnectClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
runCancel context.CancelFunc
|
||||||
|
runExited chan struct{}
|
||||||
|
runOnce sync.Once
|
||||||
|
runStarted atomic.Bool
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -70,8 +75,14 @@ func NewConnectClient(
|
|||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
|
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||||
|
// loop. runCancel is set once at construction, so Stop can call it without
|
||||||
|
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||||
|
runCtx, runCancel := context.WithCancel(ctx)
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: runCtx,
|
||||||
|
runCancel: runCancel,
|
||||||
|
runExited: make(chan struct{}),
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
engineMutex: sync.Mutex{},
|
engineMutex: sync.Mutex{},
|
||||||
@@ -135,6 +146,11 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||||
|
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||||
|
// the loop to finish (and skip the wait if the loop never ran).
|
||||||
|
c.runStarted.Store(true)
|
||||||
|
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -290,7 +306,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
c.runCancel()
|
||||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -410,14 +426,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.engine = nil
|
c.engine = nil
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
// todo: consider to remove this condition. Is not thread safe.
|
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||||
// We should always call Stop(), but we need to verify that it is idempotent
|
|
||||||
if engine.wgInterface != nil {
|
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
|
||||||
|
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
c.statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
@@ -433,12 +445,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.statusRecorder.ClientStart()
|
c.statusRecorder.ClientStart()
|
||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
c.runCancel()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -516,11 +528,9 @@ func (c *ConnectClient) Status() StatusType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) Stop() error {
|
func (c *ConnectClient) Stop() error {
|
||||||
engine := c.Engine()
|
c.runCancel()
|
||||||
if engine != nil {
|
if c.runStarted.Load() {
|
||||||
if err := engine.Stop(); err != nil {
|
<-c.runExited
|
||||||
return fmt.Errorf("stop engine: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains.
|
// Resolver caches critical NetBird infrastructure domains.
|
||||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||||
|
// guarded by mutex.
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question]*cachedRecord
|
records map[dns.Question]*cachedRecord
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// failedResolves records the last failed initial resolve per domain so a
|
||||||
|
// domain that never resolves isn't retried on every server-domains update
|
||||||
|
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||||
|
// to the current server-domains set.
|
||||||
|
failedResolves map[domain.Domain]time.Time
|
||||||
|
|
||||||
chain ChainResolver
|
chain ChainResolver
|
||||||
chainMaxPriority int
|
chainMaxPriority int
|
||||||
refreshGroup singleflight.Group
|
refreshGroup singleflight.Group
|
||||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
|||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question]*cachedRecord),
|
records: make(map[dns.Question]*cachedRecord),
|
||||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||||
cacheTTL: resolveCacheTTL(),
|
failedResolves: make(map[domain.Domain]time.Time),
|
||||||
|
cacheTTL: resolveCacheTTL(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||||
// entry for that qtype.
|
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||||
|
// the resolved family is still cached but AddDomain returns an error so the
|
||||||
|
// caller retries the incomplete resolve rather than treating it as complete.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
|
if errA != nil || errAAAA != nil {
|
||||||
|
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
delete(m.records, qAAAA)
|
delete(m.records, qAAAA)
|
||||||
delete(m.refreshing, qA)
|
delete(m.refreshing, qA)
|
||||||
delete(m.refreshing, qAAAA)
|
delete(m.refreshing, qAAAA)
|
||||||
|
delete(m.failedResolves, d)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
|||||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||||
currentDomains := m.GetCachedDomains()
|
currentDomains := m.GetCachedDomains()
|
||||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||||
|
m.pruneFailedResolves(allDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addNewDomains(ctx, newDomains)
|
m.addNewDomains(ctx, newDomains)
|
||||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
|||||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
// addNewDomains resolves and caches all domains from the update
|
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||||
|
// running the lookups concurrently. Domains already cached are skipped and left
|
||||||
|
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||||
|
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||||
|
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||||
|
// sync lock on every update.
|
||||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||||
for _, newDomain := range newDomains {
|
for _, newDomain := range newDomains {
|
||||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
if _, dup := seen[newDomain]; dup {
|
||||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
continue
|
||||||
} else {
|
}
|
||||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
seen[newDomain] = struct{}{}
|
||||||
|
|
||||||
|
if !m.needsResolve(newDomain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(d domain.Domain) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
m.markResolveFailed(d)
|
||||||
|
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.clearResolveFailed(d)
|
||||||
|
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||||
|
}(newDomain)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||||
|
// incomplete resolve gates retries on the backoff even when one family is
|
||||||
|
// already cached, so a transiently-failed family is retried instead of being
|
||||||
|
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||||
|
// to the stale-while-revalidate refresh path.
|
||||||
|
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||||
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if failedAt, ok := m.failedResolves[d]; ok {
|
||||||
|
return time.Since(failedAt) >= refreshBackoff
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||||
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||||
|
if _, ok := m.records[q]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.failedResolves[d] = time.Now()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
delete(m.failedResolves, d)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||||
|
// the server-domains set, keeping the map bounded to the current set (a
|
||||||
|
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||||
|
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
for d := range m.failedResolves {
|
||||||
|
if !slices.Contains(domains, d) {
|
||||||
|
delete(m.failedResolves, d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
calls map[string]int
|
calls map[string]int
|
||||||
answers map[string][]dns.RR
|
answers map[string][]dns.RR
|
||||||
|
qErr map[string]error
|
||||||
err error
|
err error
|
||||||
hasRoot bool
|
hasRoot bool
|
||||||
onLookup func()
|
onLookup func()
|
||||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
|||||||
return &fakeChain{
|
return &fakeChain{
|
||||||
calls: map[string]int{},
|
calls: map[string]int{},
|
||||||
answers: map[string][]dns.RR{},
|
answers: map[string][]dns.RR{},
|
||||||
|
qErr: map[string]error{},
|
||||||
hasRoot: true,
|
hasRoot: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
|||||||
f.calls[key]++
|
f.calls[key]++
|
||||||
answers := f.answers[key]
|
answers := f.answers[key]
|
||||||
err := f.err
|
err := f.err
|
||||||
|
if err == nil {
|
||||||
|
err = f.qErr[key]
|
||||||
|
}
|
||||||
onLookup := f.onLookup
|
onLookup := f.onLookup
|
||||||
f.mu.Unlock()
|
f.mu.Unlock()
|
||||||
|
|
||||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|||||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||||
|
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||||
|
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"first update must resolve the domain")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"cached domain must not be re-resolved on a subsequent update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// New domains in a single update must resolve concurrently rather than serially.
|
||||||
|
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
|
||||||
|
var inflight, maxInflight atomic.Int32
|
||||||
|
chain.onLookup = func() {
|
||||||
|
n := inflight.Add(1)
|
||||||
|
for {
|
||||||
|
old := maxInflight.Load()
|
||||||
|
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
inflight.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||||
|
for _, d := range relays {
|
||||||
|
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||||
|
}
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||||
|
require.NoError(t, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||||
|
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||||
|
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A domain that fails to resolve must not be retried on every update; the
|
||||||
|
// failure backoff suppresses re-resolution until it expires.
|
||||||
|
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("resolve boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"first update must attempt the resolve")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||||
|
"failed resolve must back off and not retry on the next update")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||||
|
// the same host) must be resolved once per update, not once per occurrence.
|
||||||
|
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
sd := dnsconfig.ServerDomains{
|
||||||
|
Stuns: []domain.Domain{"dup.example.com"},
|
||||||
|
Turns: []domain.Domain{"dup.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||||
|
"a domain appearing under multiple server-domain types must resolve once")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||||
|
// so the map stays bounded to the current set.
|
||||||
|
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("resolve boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.True(t, marked, "failed resolve must be recorded")
|
||||||
|
|
||||||
|
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||||
|
}
|
||||||
|
|
||||||
|
// When one family hard-errors while the other resolves, the domain is cached
|
||||||
|
// for the working family but recorded as incomplete so the failed family is
|
||||||
|
// retried under backoff instead of being treated as fully resolved forever.
|
||||||
|
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||||
|
d := domain.Domain("relay.example.com")
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||||
|
_, marked := r.failedResolves[d]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.True(t, aCached, "the working family must still be cached")
|
||||||
|
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||||
|
|
||||||
|
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||||
|
r.mutex.Unlock()
|
||||||
|
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||||
|
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||||
|
// re-resolved on every sync.
|
||||||
|
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||||
|
d := domain.Domain("v4only.example.com")
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, marked := r.failedResolves[d]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||||
|
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||||
|
}
|
||||||
@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
|
|||||||
}
|
}
|
||||||
return "[" + strings.Join(parts, ", ") + "]"
|
return "[" + strings.Join(parts, ", ") + "]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||||
|
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||||
|
// advertise EDNS0.
|
||||||
|
func StripOPT(msg *dns.Msg) {
|
||||||
|
if len(msg.Extra) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := msg.Extra[:0]
|
||||||
|
for _, rr := range msg.Extra {
|
||||||
|
if _, ok := rr.(*dns.OPT); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, rr)
|
||||||
|
}
|
||||||
|
msg.Extra = out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||||
|
// the message, if present.
|
||||||
|
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||||
|
opt := msg.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
for _, o := range opt.Option {
|
||||||
|
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||||
|
return ede, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|||||||
@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripOPT(t *testing.T) {
|
||||||
|
rm := &dns.Msg{
|
||||||
|
Extra: []dns.RR{
|
||||||
|
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||||
|
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
StripOPT(rm)
|
||||||
|
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||||
|
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||||
|
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractEDE(t *testing.T) {
|
||||||
|
t.Run("no edns", func(t *testing.T) {
|
||||||
|
_, ok := ExtractEDE(&dns.Msg{})
|
||||||
|
assert.False(t, ok, "message without OPT has no EDE")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("edns without ede", func(t *testing.T) {
|
||||||
|
rm := &dns.Msg{}
|
||||||
|
rm.SetEdns0(4096, false)
|
||||||
|
_, ok := ExtractEDE(rm)
|
||||||
|
assert.False(t, ok, "OPT without EDE option returns false")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with ede", func(t *testing.T) {
|
||||||
|
rm := &dns.Msg{}
|
||||||
|
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||||
|
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||||
|
rm.Extra = append(rm.Extra, opt)
|
||||||
|
|
||||||
|
ede, ok := ExtractEDE(rm)
|
||||||
|
assert.True(t, ok, "EDE option should be found")
|
||||||
|
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||||
|
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -38,11 +39,15 @@ const (
|
|||||||
// defaultWarningDelayBase is the starting grace window before a
|
// defaultWarningDelayBase is the starting grace window before a
|
||||||
// "Nameserver group unreachable" event fires for a group that's
|
// "Nameserver group unreachable" event fires for a group that's
|
||||||
// never been healthy and only has overlay upstreams with no
|
// never been healthy and only has overlay upstreams with no
|
||||||
// Connected peer. Per-server and overridable; see warningDelayFor.
|
// Connected peer. Per-server and overridable via envWarningDelay;
|
||||||
defaultWarningDelayBase = 30 * time.Second
|
// see warningDelay.
|
||||||
|
defaultWarningDelayBase = 60 * time.Second
|
||||||
// warningDelayBonusCap caps the route-count bonus added to the
|
// warningDelayBonusCap caps the route-count bonus added to the
|
||||||
// base grace window. See warningDelayFor.
|
// base grace window. See warningDelay.
|
||||||
warningDelayBonusCap = 30 * time.Second
|
warningDelayBonusCap = 30 * time.Second
|
||||||
|
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
|
||||||
|
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
|
||||||
|
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
|
||||||
)
|
)
|
||||||
|
|
||||||
// errNoUsableNameservers signals that a merged-domain group has no usable
|
// errNoUsableNameservers signals that a merged-domain group has no usable
|
||||||
@@ -135,7 +140,7 @@ type DefaultServer struct {
|
|||||||
disableSys bool
|
disableSys bool
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxHandlers []handlerWrapper
|
||||||
localResolver *local.Resolver
|
localResolver *local.Resolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
@@ -199,8 +204,6 @@ type handlerWrapper struct {
|
|||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
|
||||||
|
|
||||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||||
type DefaultServerConfig struct {
|
type DefaultServerConfig struct {
|
||||||
WgInterface WGIface
|
WgInterface WGIface
|
||||||
@@ -289,7 +292,6 @@ func newDefaultServer(
|
|||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: handlerChain,
|
handlerChain: handlerChain,
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
@@ -298,7 +300,7 @@ func newDefaultServer(
|
|||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
mgmtCacheResolver: mgmtCacheResolver,
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||||
warningDelayBase: defaultWarningDelayBase,
|
warningDelayBase: warningDelayBaseFromEnv(),
|
||||||
healthRefresh: make(chan struct{}, 1),
|
healthRefresh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
// Wire the local resolver against the peer status recorder so it can
|
// Wire the local resolver against the peer status recorder so it can
|
||||||
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
|
|||||||
type routeSettable interface {
|
type routeSettable interface {
|
||||||
setSelectedRoutes(func() route.HAMap)
|
setSelectedRoutes(func() route.HAMap)
|
||||||
}
|
}
|
||||||
for _, entry := range s.dnsMuxMap {
|
for _, entry := range s.dnsMuxHandlers {
|
||||||
if h, ok := entry.handler.(routeSettable); ok {
|
if h, ok := entry.handler.(routeSettable); ok {
|
||||||
h.setSelectedRoutes(selected)
|
h.setSelectedRoutes(selected)
|
||||||
}
|
}
|
||||||
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
|
|||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
for _, existing := range s.dnsMuxMap {
|
for _, existing := range s.dnsMuxHandlers {
|
||||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
existing.handler.Stop()
|
// The local resolver is a persistent singleton shared by every custom
|
||||||
|
// zone and reused across config updates. Its chain registrations are
|
||||||
|
// per-config and must be deregistered, but Stop() cancels its lookup
|
||||||
|
// context (breaking external CNAME-target resolution) and clears its
|
||||||
|
// records, so it must not be torn down here.
|
||||||
|
if existing.handler != s.localResolver {
|
||||||
|
existing.handler.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.handler.ID()] = update
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxHandlers = muxUpdates
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateNSGroupStates records the new group set and pokes the refresher.
|
// updateNSGroupStates records the new group set and pokes the refresher.
|
||||||
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// warningDelayBaseFromEnv returns the base grace window, honoring
|
||||||
|
// envWarningDelay when it holds a valid positive Go duration. Invalid or
|
||||||
|
// non-positive values fall back to defaultWarningDelayBase.
|
||||||
|
func warningDelayBaseFromEnv() time.Duration {
|
||||||
|
val := os.Getenv(envWarningDelay)
|
||||||
|
if val == "" {
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
d, err := time.ParseDuration(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
|
||||||
|
return defaultWarningDelayBase
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
// warningDelay returns the grace window for the given selected-route
|
// warningDelay returns the grace window for the given selected-route
|
||||||
// count. Scales gently: +1s per 100 routes, capped by
|
// count. Scales gently: +1s per 100 routes, capped by
|
||||||
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
|
||||||
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
|
|||||||
// in more than one handler.
|
// in more than one handler.
|
||||||
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||||
merged := make(map[netip.AddrPort]UpstreamHealth)
|
merged := make(map[netip.AddrPort]UpstreamHealth)
|
||||||
for _, entry := range s.dnsMuxMap {
|
for _, entry := range s.dnsMuxHandlers {
|
||||||
reporter, ok := entry.handler.(upstreamHealthReporter)
|
reporter, ok := entry.handler.(upstreamHealthReporter)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
|
|||||||
485
client/internal/dns/server_privileged_test.go
Normal file
485
client/internal/dns/server_privileged_test.go
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initUpstreamMap []handlerWrapper
|
||||||
|
initLocalZones []nbdns.CustomZone
|
||||||
|
initSerial uint64
|
||||||
|
inputSerial uint64
|
||||||
|
inputUpdate nbdns.Config
|
||||||
|
shouldFail bool
|
||||||
|
expectedUpstreamMap []handlerWrapper
|
||||||
|
expectedLocalQs []dns.Question
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Initial Config Should Succeed",
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New Config Should Succeed",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 2,
|
||||||
|
inputSerial: 1,
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{{
|
||||||
|
domain: ".",
|
||||||
|
priority: PriorityDefault,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disabled Service Should clean map",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||||
|
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wgIface.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = dnsServer.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||||
|
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||||
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
|
if err != nil {
|
||||||
|
if testCase.shouldFail {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||||
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range testCase.expectedUpstreamMap {
|
||||||
|
found := false
|
||||||
|
for _, got := range dnsServer.dnsMuxHandlers {
|
||||||
|
if got.domain == expected.domain && got.priority == expected.priority {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, q := range testCase.expectedLocalQs {
|
||||||
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||||
|
Question: []dns.Question{q},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testCase.expectedLocalQs) > 0 {
|
||||||
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||||
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create stdnet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: "utun2301",
|
||||||
|
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create and init wireguard interface: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = wgIface.Close(); err != nil {
|
||||||
|
t.Logf("close wireguard interface: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("run DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||||
|
t.Logf("restore DNS settings on the host: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &local.Resolver{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||||
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server with regular configuration
|
||||||
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := update
|
||||||
|
update2.ServiceEnable = false
|
||||||
|
// Disable the server, stop the listener
|
||||||
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update3 := update2
|
||||||
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||||
|
// But service still get updates and we checking that we handle
|
||||||
|
// internal state in the right way
|
||||||
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -23,7 +22,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -104,481 +102,6 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
|
||||||
var srvs []netip.AddrPort
|
|
||||||
for _, srv := range servers {
|
|
||||||
srvs = append(srvs, srv.AddrPort())
|
|
||||||
}
|
|
||||||
u := &upstreamResolverBase{
|
|
||||||
domain: domain.Domain(d),
|
|
||||||
cancel: func() {},
|
|
||||||
}
|
|
||||||
u.addRace(srvs)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
dummyHandler := local.NewResolver()
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
initUpstreamMap registeredHandlerMap
|
|
||||||
initLocalZones []nbdns.CustomZone
|
|
||||||
initSerial uint64
|
|
||||||
inputSerial uint64
|
|
||||||
inputUpdate nbdns.Config
|
|
||||||
shouldFail bool
|
|
||||||
expectedUpstreamMap registeredHandlerMap
|
|
||||||
expectedLocalQs []dns.Question
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Initial Config Should Succeed",
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
|
||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
|
||||||
domain: "netbird.io",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
dummyHandler.ID(): handlerWrapper{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityDefault,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "New Config Should Succeed",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: registeredHandlerMap{
|
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
|
||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
|
||||||
domain: "netbird.io",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
"local-resolver": handlerWrapper{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
|
||||||
initSerial: 2,
|
|
||||||
inputSerial: 1,
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
|
||||||
domain: ".",
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityDefault,
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: registeredHandlerMap{
|
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Disabled Service Should clean map",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: registeredHandlerMap{
|
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: dummyHandler,
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
|
||||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = wgIface.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = dnsServer.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
|
||||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
|
||||||
if err != nil {
|
|
||||||
if testCase.shouldFail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
|
||||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
|
||||||
}
|
|
||||||
|
|
||||||
for key := range testCase.expectedUpstreamMap {
|
|
||||||
_, found := dnsServer.dnsMuxMap[key]
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
|
||||||
responseWriter := &test.MockResponseWriter{
|
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
responseMSG = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, q := range testCase.expectedLocalQs {
|
|
||||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
|
||||||
Question: []dns.Question{q},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(testCase.expectedLocalQs) > 0 {
|
|
||||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
|
||||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create stdnet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: "utun2301",
|
|
||||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("build interface wireguard: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create and init wireguard interface: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = wgIface.Close(); err != nil {
|
|
||||||
t.Logf("close wireguard interface: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
||||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
||||||
t.Errorf("set packet filter: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("run DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
||||||
t.Logf("restore DNS settings on the host: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
|
||||||
"id1": handlerWrapper{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &local.Resolver{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
|
||||||
dnsServer.updateSerial = 0
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server with regular configuration
|
|
||||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update2 := update
|
|
||||||
update2.ServiceEnable = false
|
|
||||||
// Disable the server, stop the listener
|
|
||||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update3 := update2
|
|
||||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
||||||
// But service still get updates and we checking that we handle
|
|
||||||
// internal state in the right way
|
|
||||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1029,15 +552,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
|
|||||||
func (m *mockService) DeregisterMux(string) {}
|
func (m *mockService) DeregisterMux(string) {}
|
||||||
|
|
||||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||||
baseMatchHandlers := registeredHandlerMap{
|
baseMatchHandlers := []handlerWrapper{
|
||||||
"upstream-group1": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
@@ -1046,15 +569,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
baseRootHandlers := registeredHandlerMap{
|
baseRootHandlers := []handlerWrapper{
|
||||||
"upstream-root1": {
|
{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-root1",
|
Id: "upstream-root1",
|
||||||
},
|
},
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
"upstream-root2": {
|
{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-root2",
|
Id: "upstream-root2",
|
||||||
@@ -1063,22 +586,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
baseMixedHandlers := registeredHandlerMap{
|
baseMixedHandlers := []handlerWrapper{
|
||||||
"upstream-group1": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityUpstream - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
"upstream-other": {
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
@@ -1089,7 +612,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialHandlers registeredHandlerMap
|
initialHandlers []handlerWrapper
|
||||||
updates []handlerWrapper
|
updates []handlerWrapper
|
||||||
expectedHandlers map[string]string // map[HandlerID]domain
|
expectedHandlers map[string]string // map[HandlerID]domain
|
||||||
description string
|
description string
|
||||||
@@ -1373,32 +896,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
server := &DefaultServer{
|
server := &DefaultServer{
|
||||||
dnsMuxMap: tt.initialHandlers,
|
dnsMuxHandlers: tt.initialHandlers,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
service: &mockService{},
|
service: &mockService{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform the update
|
// Perform the update
|
||||||
server.updateMux(tt.updates)
|
server.updateMux(tt.updates)
|
||||||
|
|
||||||
// Verify the results
|
// Verify the results
|
||||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
|
||||||
"Number of handlers after update doesn't match expected")
|
"Number of handlers after update doesn't match expected")
|
||||||
|
|
||||||
// Check each expected handler
|
// Check each expected handler
|
||||||
for id, expectedDomain := range tt.expectedHandlers {
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
var found *handlerWrapper
|
||||||
assert.True(t, exists, "Expected handler %s not found", id)
|
for i := range server.dnsMuxHandlers {
|
||||||
if exists {
|
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
|
||||||
assert.Equal(t, expectedDomain, handler.domain,
|
found = &server.dnsMuxHandlers[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotNil(t, found, "Expected handler %s not found", id)
|
||||||
|
if found != nil {
|
||||||
|
assert.Equal(t, expectedDomain, found.domain,
|
||||||
"Domain mismatch for handler %s", id)
|
"Domain mismatch for handler %s", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify no unexpected handlers exist
|
// Verify no unexpected handlers exist
|
||||||
for HandlerID := range server.dnsMuxMap {
|
for _, entry := range server.dnsMuxHandlers {
|
||||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
_, expected := tt.expectedHandlers[string(entry.handler.ID())]
|
||||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the handlerChain state and order
|
// Verify the handlerChain state and order
|
||||||
@@ -1413,7 +942,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
// Verify handler exists in mux
|
// Verify handler exists in mux
|
||||||
foundInMux := false
|
foundInMux := false
|
||||||
for _, muxEntry := range server.dnsMuxMap {
|
for _, muxEntry := range server.dnsMuxHandlers {
|
||||||
if chainEntry.Handler == muxEntry.handler &&
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
chainEntry.Priority == muxEntry.priority &&
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||||
@@ -1422,12 +951,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, foundInMux,
|
assert.True(t, foundInMux,
|
||||||
"Handler in chain not found in dnsMuxMap")
|
"Handler in chain not found in dnsMuxHandlers")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// chainHasPattern reports whether the handler chain holds an entry registered
|
||||||
|
// for the given fqdn pattern at the given priority.
|
||||||
|
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
|
||||||
|
for _, h := range s.handlerChain.handlers {
|
||||||
|
if h.OrigPattern == pattern && h.Priority == priority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
|
||||||
|
// tracks each (handler, domain) registration independently when one handler
|
||||||
|
// serves multiple zones. Every custom zone is served by the same handler
|
||||||
|
// instance (the local resolver, whose ID is the constant "local-resolver"), so
|
||||||
|
// removing one zone must deregister exactly that zone's chain entry and leave
|
||||||
|
// the others in place. Tracking registrations by handler ID alone collapses all
|
||||||
|
// zones onto one entry, leaving removed zones in the chain to answer
|
||||||
|
// authoritatively with no records.
|
||||||
|
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
|
||||||
|
// One handler serves every custom zone, mirroring s.localResolver.
|
||||||
|
shared := &mockHandler{Id: "local-resolver"}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two custom zones under the same handler. The surviving zone is registered
|
||||||
|
// last, mirroring the management emission order.
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||||
|
"userzone.test should be registered after the first update")
|
||||||
|
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||||
|
"peerzone.test should be registered after the first update")
|
||||||
|
|
||||||
|
// Remove one zone, keep the other.
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
|
||||||
|
"peerzone.test should remain after removing userzone.test")
|
||||||
|
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
|
||||||
|
"userzone.test handler must be deregistered, not leaked in the chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
|
||||||
|
// does not tear down the shared local resolver during reconfiguration. The
|
||||||
|
// resolver is a process-lifetime singleton reused across config updates;
|
||||||
|
// Stop() cancels its lookup context (breaking external CNAME-target
|
||||||
|
// resolution) and clears its records. updateMux must deregister its chain
|
||||||
|
// entries without stopping it. Records surviving a teardown update is the
|
||||||
|
// observable proxy: Stop() would have cleared them.
|
||||||
|
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
|
||||||
|
resolver := local.NewResolver()
|
||||||
|
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
|
||||||
|
Name: "peer.netbird.cloud.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "10.0.0.1",
|
||||||
|
}))
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
localResolver: resolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
server.updateMux([]handlerWrapper{
|
||||||
|
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Remove the zone. The resolver must survive so its records and lookup
|
||||||
|
// context stay intact for the next registration.
|
||||||
|
server.updateMux(nil)
|
||||||
|
|
||||||
|
var response *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
response = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
|
||||||
|
|
||||||
|
require.NotNil(t, response, "local resolver should answer after teardown")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
|
||||||
|
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
|
||||||
|
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
|
||||||
|
}
|
||||||
|
|
||||||
func TestExtraDomains(t *testing.T) {
|
func TestExtraDomains(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -2049,7 +1674,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
|||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
groups := []*nbdns.NameServerGroup{
|
groups := []*nbdns.NameServerGroup{
|
||||||
@@ -2207,7 +1831,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
|
||||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||||
// without spinning up real handlers.
|
// without spinning up real handlers.
|
||||||
type healthStubHandler struct {
|
type healthStubHandler struct {
|
||||||
@@ -2283,12 +1907,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||||
activeRoutes: func() route.HAMap { return fx.active },
|
activeRoutes: func() route.HAMap { return fx.active },
|
||||||
warningDelayBase: defaultWarningDelayBase,
|
warningDelayBase: defaultWarningDelayBase,
|
||||||
}
|
}
|
||||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
fx.server.mux.Lock()
|
fx.server.mux.Lock()
|
||||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||||
@@ -2395,7 +2018,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return nil },
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
warningDelayBase: 50 * time.Millisecond,
|
warningDelayBase: 50 * time.Millisecond,
|
||||||
@@ -2407,7 +2029,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
|||||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
}}
|
}}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2444,7 +2066,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
service: NewServiceViaMemory(wgIface),
|
service: NewServiceViaMemory(wgIface),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
extraDomains: map[domain.Domain]int{},
|
extraDomains: map[domain.Domain]int{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
statusRecorder: peer.NewRecorder("mgm"),
|
statusRecorder: peer.NewRecorder("mgm"),
|
||||||
selectedRoutes: func() route.HAMap { return nil },
|
selectedRoutes: func() route.HAMap { return nil },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
@@ -2459,7 +2080,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||||
}
|
}
|
||||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2484,6 +2105,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
|
|||||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||||
// comes up and a query succeeds before the grace window elapses. No
|
// comes up and a query succeeds before the grace window elapses. No
|
||||||
// warning should ever have fired, and no recovery either.
|
// warning should ever have fired, and no recovery either.
|
||||||
|
func TestWarningDelayBaseFromEnv(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
set bool
|
||||||
|
val string
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
|
||||||
|
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
|
||||||
|
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
|
||||||
|
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
|
||||||
|
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
|
||||||
|
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Setenv(envWarningDelay, tc.val)
|
||||||
|
if !tc.set {
|
||||||
|
os.Unsetenv(envWarningDelay)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||||
fx := newProjTestFixture(t)
|
fx := newProjTestFixture(t)
|
||||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||||
@@ -2595,7 +2242,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
|||||||
server := &DefaultServer{
|
server := &DefaultServer{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
statusRecorder: recorder,
|
statusRecorder: recorder,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||||
activeRoutes: func() route.HAMap { return nil },
|
activeRoutes: func() route.HAMap { return nil },
|
||||||
warningDelayBase: time.Hour,
|
warningDelayBase: time.Hour,
|
||||||
@@ -2613,7 +2259,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
|||||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
|
||||||
|
|
||||||
server.mux.Lock()
|
server.mux.Lock()
|
||||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||||
@@ -2640,7 +2286,6 @@ func TestDNSLoopPrevention(t *testing.T) {
|
|||||||
localResolver: local.NewResolver(),
|
localResolver: local.NewResolver(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -443,29 +443,32 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
|||||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A valid response means the upstream is reachable, whatever the Rcode.
|
||||||
|
u.markUpstreamOk(upstream)
|
||||||
|
|
||||||
proto := ""
|
proto := ""
|
||||||
if upstreamProto != nil {
|
if upstreamProto != nil {
|
||||||
proto = upstreamProto.protocol
|
proto = upstreamProto.protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||||
|
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
|
||||||
|
// refused zones, transient recursion errors), not reachability
|
||||||
|
// problems: fail over for a better answer but keep the upstream healthy.
|
||||||
if code, ok := nonRetryableEDE(rm); ok {
|
if code, ok := nonRetryableEDE(rm); ok {
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
resutil.StripOPT(rm)
|
||||||
}
|
}
|
||||||
u.markUpstreamOk(upstream)
|
|
||||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||||
}
|
}
|
||||||
reason := dns.RcodeToString[rm.Rcode]
|
reason := dns.RcodeToString[rm.Rcode]
|
||||||
u.markUpstreamFail(upstream, reason)
|
|
||||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hadEdns {
|
if !hadEdns {
|
||||||
stripOPT(rm)
|
resutil.StripOPT(rm)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.markUpstreamOk(upstream)
|
|
||||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,22 +523,6 @@ func upstreamUDPSize() uint16 {
|
|||||||
return dns.MinMsgSize
|
return dns.MinMsgSize
|
||||||
}
|
}
|
||||||
|
|
||||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
|
||||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
|
||||||
func stripOPT(rm *dns.Msg) {
|
|
||||||
if len(rm.Extra) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out := rm.Extra[:0]
|
|
||||||
for _, rr := range rm.Extra {
|
|
||||||
if _, ok := rr.(*dns.OPT); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, rr)
|
|
||||||
}
|
|
||||||
rm.Extra = out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||||
|
|||||||
@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
|||||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
|
||||||
|
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
|
||||||
|
// those are per-question outcomes from a reachable server and must not mark
|
||||||
|
// the upstream unhealthy. Only transport failures (timeouts) do.
|
||||||
|
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
|
||||||
|
a := netip.MustParseAddrPort("192.0.2.10:53")
|
||||||
|
b := netip.MustParseAddrPort("192.0.2.11:53")
|
||||||
|
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
respA mockUpstreamResponse
|
||||||
|
respB mockUpstreamResponse
|
||||||
|
wantHealthy bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "both SERVFAIL are reachable",
|
||||||
|
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||||
|
wantHealthy: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both REFUSED are reachable",
|
||||||
|
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
|
||||||
|
wantHealthy: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout marks unhealthy",
|
||||||
|
respA: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
respB: mockUpstreamResponse{err: timeoutErr},
|
||||||
|
wantHealthy: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mockClient := &mockUpstreamResolverPerServer{
|
||||||
|
responses: map[string]mockUpstreamResponse{
|
||||||
|
a.String(): tc.respA,
|
||||||
|
b.String(): tc.respB,
|
||||||
|
},
|
||||||
|
rtt: time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := &upstreamResolverBase{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: mockClient,
|
||||||
|
upstreamTimeout: UpstreamTimeout,
|
||||||
|
}
|
||||||
|
resolver.addRace([]netip.AddrPort{a, b})
|
||||||
|
|
||||||
|
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||||
|
|
||||||
|
health := resolver.UpstreamHealth()
|
||||||
|
require.Contains(t, health, a, "primary upstream should have a health record")
|
||||||
|
if tc.wantHealthy {
|
||||||
|
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
|
||||||
|
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
|
||||||
|
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
|
||||||
|
} else {
|
||||||
|
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
|
||||||
|
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFormatFailures(t *testing.T) {
|
func TestFormatFailures(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -913,19 +985,6 @@ func TestEDEName(t *testing.T) {
|
|||||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStripOPT(t *testing.T) {
|
|
||||||
rm := &dns.Msg{
|
|
||||||
Extra: []dns.RR{
|
|
||||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
|
||||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
stripOPT(rm)
|
|
||||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
|
||||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
|
||||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||||
|
|||||||
@@ -26,6 +26,15 @@ import (
|
|||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
const upstreamTimeout = 15 * time.Second
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||||
|
// client can see the reason without inspecting this peer's logs. They live in
|
||||||
|
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||||
|
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||||
|
const (
|
||||||
|
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||||
|
edeNetbirdUpstreamFailure uint16 = 49153
|
||||||
|
)
|
||||||
|
|
||||||
type resolver interface {
|
type resolver interface {
|
||||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
}
|
}
|
||||||
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
|||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
reqHasEdns bool,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reqHasEdns {
|
||||||
|
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||||
|
}
|
||||||
|
|
||||||
f.writeResponse(logger, w, resp, domain, startTime)
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
|||||||
|
|
||||||
return selectedResId, matches
|
return selectedResId, matches
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||||
|
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||||
|
if dnsErr != nil && dnsErr.IsTimeout {
|
||||||
|
return edeNetbirdUpstreamTimeout
|
||||||
|
}
|
||||||
|
return edeNetbirdUpstreamFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||||
|
// It deliberately omits the upstream server address, which may be an internal
|
||||||
|
// resolver and is exposed to any client permitted to use the route; the full
|
||||||
|
// detail stays in the forwarder's local log.
|
||||||
|
func edeText(dnsErr *net.DNSError) string {
|
||||||
|
if dnsErr != nil && dnsErr.IsTimeout {
|
||||||
|
return "netbird forwarder: upstream timeout"
|
||||||
|
}
|
||||||
|
return "netbird forwarder: upstream failure"
|
||||||
|
}
|
||||||
|
|
||||||
|
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||||
|
// creating the OPT pseudo-record if the response does not already carry one.
|
||||||
|
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||||
|
opt := resp.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
opt = resp.IsEdns0()
|
||||||
|
}
|
||||||
|
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
lookupErr error
|
||||||
|
reqEdns bool
|
||||||
|
wantEDE bool
|
||||||
|
wantCode uint16
|
||||||
|
wantTextHas string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "timeout with edns0",
|
||||||
|
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||||
|
reqEdns: true,
|
||||||
|
wantEDE: true,
|
||||||
|
wantCode: edeNetbirdUpstreamTimeout,
|
||||||
|
wantTextHas: "netbird forwarder: upstream timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server failure with edns0",
|
||||||
|
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||||
|
reqEdns: true,
|
||||||
|
wantEDE: true,
|
||||||
|
wantCode: edeNetbirdUpstreamFailure,
|
||||||
|
wantTextHas: "netbird forwarder: upstream failure",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no edns0 in request omits ede",
|
||||||
|
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||||
|
reqEdns: false,
|
||||||
|
wantEDE: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||||
|
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
if tt.reqEdns {
|
||||||
|
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "expected a response")
|
||||||
|
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||||
|
|
||||||
|
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||||
|
if !tt.wantEDE {
|
||||||
|
assert.False(t, ok, "response must not carry EDE")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.True(t, ok, "response must carry EDE")
|
||||||
|
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||||
|
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||||
|
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
// Test that large UDP responses are truncated with TC bit set
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
mockResolver := &MockResolver{}
|
mockResolver := &MockResolver{}
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ const (
|
|||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
|
|
||||||
|
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||||
|
|
||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
@@ -199,6 +201,8 @@ type Engine struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
started bool
|
||||||
|
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
|
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
@@ -206,6 +210,12 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
// forwardingRules holds the ingress forward rules applied for the current target.
|
||||||
|
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
|
||||||
|
// it is stashed here so the final, peer-converged pass can build the lazy-connection
|
||||||
|
// exclude list without recomputing them on every bounded peer pass.
|
||||||
|
forwardingRules []firewallManager.ForwardRule
|
||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
@@ -279,9 +289,15 @@ func NewEngine(
|
|||||||
services EngineServices,
|
services EngineServices,
|
||||||
mobileDep MobileDependency,
|
mobileDep MobileDependency,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
// The engine is single-use: a fresh instance is built per connection
|
||||||
|
// cycle (see Client.run), so the run context is created once here rather
|
||||||
|
// than in Start.
|
||||||
|
ctx, cancel := context.WithCancel(clientCtx)
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
signal: services.SignalClient,
|
signal: services.SignalClient,
|
||||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||||
mgmClient: services.MgmClient,
|
mgmClient: services.MgmClient,
|
||||||
@@ -314,8 +330,34 @@ func (e *Engine) Stop() error {
|
|||||||
log.Debugf("tried stopping engine that is nil")
|
log.Debugf("tried stopping engine that is nil")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
e.cancel()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
|
|
||||||
|
e.stopLocked()
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
timeout := e.calculateShutdownTimeout()
|
||||||
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||||
|
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopLocked tears down everything Start may have brought up, in the order
|
||||||
|
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||||
|
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||||
|
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||||
|
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||||
|
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||||
|
func (e *Engine) stopLocked() {
|
||||||
if e.connMgr != nil {
|
if e.connMgr != nil {
|
||||||
e.connMgr.Close()
|
e.connMgr.Close()
|
||||||
}
|
}
|
||||||
@@ -366,10 +408,6 @@ func (e *Engine) Stop() error {
|
|||||||
// so dbus and friends don't complain because of a missing interface
|
// so dbus and friends don't complain because of a missing interface
|
||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.cancel != nil {
|
|
||||||
e.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
@@ -388,21 +426,6 @@ func (e *Engine) Stop() error {
|
|||||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
timeout := e.calculateShutdownTimeout()
|
|
||||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
|
||||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||||
@@ -440,18 +463,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
|||||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
// The engine is single-use. Reject a duplicate start and a start on an
|
||||||
|
// already-stopped engine (run context cancelled).
|
||||||
|
if e.started {
|
||||||
|
return ErrEngineAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||||
|
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.started = true
|
||||||
|
|
||||||
|
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||||
|
// run context first so goroutines started before the failure (connMgr,
|
||||||
|
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||||
|
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||||
|
// not just what close() covers.
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
e.cancel()
|
||||||
|
e.stopLocked()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.cancel != nil {
|
|
||||||
e.cancel()
|
|
||||||
}
|
|
||||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
|
||||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
@@ -485,13 +528,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("read initial settings: %w", err)
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
@@ -526,7 +567,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
if err = e.wgInterfaceCreate(); err != nil {
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,7 +575,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.createFirewall(); err != nil {
|
if err := e.createFirewall(); err != nil {
|
||||||
e.close()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -547,7 +586,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.udpMux, err = e.wgInterface.Up()
|
e.udpMux, err = e.wgInterface.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -572,9 +610,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.acl = acl.NewDefaultManager(e.firewall)
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.Initialize()
|
if err := e.dnsServer.Initialize(); err != nil {
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,7 +622,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start(peer.IsForceRelayed())
|
e.srWatcher.Start(peer.IsForceRelayed())
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
if err = e.receiveSignalEvents(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveJobEvents()
|
e.receiveJobEvents()
|
||||||
|
|
||||||
@@ -638,7 +676,6 @@ func (e *Engine) createFirewall() error {
|
|||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("set firewall: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -731,7 +768,15 @@ func (e *Engine) blockLanAccess() {
|
|||||||
|
|
||||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
|
||||||
|
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
|
||||||
|
// batch at a time and other subsystems can interleave between passes. It is
|
||||||
|
// passed in (not read globally) so tests can exercise the multi-pass path.
|
||||||
|
const maxPeersPerSyncPass = 300
|
||||||
|
|
||||||
|
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
|
||||||
|
// when more changed peers remained than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
|
|
||||||
// first, check if peers have been modified
|
// first, check if peers have been modified
|
||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
@@ -761,26 +806,32 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
more := false
|
||||||
|
if len(modified) > maxBatch {
|
||||||
|
modified = modified[:maxBatch]
|
||||||
|
more = true
|
||||||
|
}
|
||||||
|
|
||||||
// second, close all modified connections and remove them from the state map
|
// second, close all modified connections and remove them from the state map
|
||||||
for _, p := range modified {
|
for _, p := range modified {
|
||||||
err := e.removePeer(p.GetWgPubKey())
|
if err := e.removePeer(p.GetWgPubKey()); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// third, add the peer connections again
|
// third, add the peer connections again
|
||||||
for _, p := range modified {
|
for _, p := range modified {
|
||||||
err := e.addNewPeer(p)
|
if err := e.addNewPeer(p); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// removePeers removes up to maxBatch peers per call. It returns true when more
|
||||||
|
// peers remained to remove than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
newPeers := make([]string, 0, len(peersUpdate))
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
newPeers = append(newPeers, p.GetWgPubKey())
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
@@ -788,14 +839,19 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
|
|
||||||
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||||
|
|
||||||
|
more := false
|
||||||
|
if len(toRemove) > maxBatch {
|
||||||
|
toRemove = toRemove[:maxBatch]
|
||||||
|
more = true
|
||||||
|
}
|
||||||
|
|
||||||
for _, p := range toRemove {
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
if err := e.removePeer(p); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
log.Infof("removed peer %s", p)
|
log.Infof("removed peer %s", p)
|
||||||
}
|
}
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) removeAllPeers() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
@@ -864,19 +920,17 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
|
||||||
started := time.Now()
|
// returns true if more peers remained than the per-pass cap. It is driven by the
|
||||||
defer func() {
|
// mapStateManager, which re-invokes it (releasing the lock between passes) until
|
||||||
duration := time.Since(started)
|
// the update is fully applied.
|
||||||
log.Infof("sync finished in %s", duration)
|
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
|
||||||
}()
|
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
if e.ctx.Err() != nil {
|
if e.ctx.Err() != nil {
|
||||||
return e.ctx.Err()
|
return false, e.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||||
@@ -884,7 +938,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Posture checks are bound to the network map presence:
|
// Posture checks are bound to the network map presence:
|
||||||
@@ -894,23 +948,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
// leave the previously applied checks untouched
|
// leave the previously applied checks untouched
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil {
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
e.persistSyncResponse(update)
|
|
||||||
|
|
||||||
// only apply new changes and ignore old ones
|
// only apply new changes and ignore old ones
|
||||||
if err := e.updateNetworkMap(nm); err != nil {
|
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
|
||||||
return err
|
if err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||||
|
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||||
@@ -956,6 +1009,13 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
|||||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||||
|
// Only persist updates that carry a network map. Config-only updates (e.g. relay
|
||||||
|
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
|
||||||
|
// the last full map on disk and break restore-on-restart.
|
||||||
|
if update.GetNetworkMap() == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
e.syncRespMux.RLock()
|
e.syncRespMux.RLock()
|
||||||
defer e.syncRespMux.RUnlock()
|
defer e.syncRespMux.RUnlock()
|
||||||
|
|
||||||
@@ -1035,7 +1095,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
}
|
}
|
||||||
e.checks = checks
|
e.checks = checks
|
||||||
|
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
@@ -1066,6 +1126,20 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||||
|
// can be excluded from the reported network addresses; the interface coming and
|
||||||
|
// going otherwise churns the peer meta on the management server.
|
||||||
|
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||||
|
var ips []netip.Addr
|
||||||
|
if e.config.WgAddr.IP.IsValid() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IP)
|
||||||
|
}
|
||||||
|
if e.config.WgAddr.HasIPv6() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IPv6)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
@@ -1209,7 +1283,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
@@ -1233,7 +1307,19 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
// The map-state manager converges the latest update in the background in
|
||||||
|
// bounded passes; the stream callback only hands it the newest target.
|
||||||
|
manager := newMapStateManager(e.applySyncPass, e.persistSyncResponse, func(d time.Duration) {
|
||||||
|
log.Infof("sync finished in %s", d)
|
||||||
|
e.clientMetrics.RecordSyncDuration(e.ctx, d)
|
||||||
|
})
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
|
manager.run(e.ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1284,21 +1370,104 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
|
||||||
|
// and up to maxBatch peers per phase. It returns true when more peers remained
|
||||||
|
// than the cap, so the caller re-runs until convergence.
|
||||||
|
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
|
||||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||||
if networkMap.GetPeerConfig() != nil {
|
if networkMap.GetPeerConfig() != nil {
|
||||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := networkMap.GetSerial()
|
serial := networkMap.GetSerial()
|
||||||
if e.networkSerial > serial {
|
if e.networkSerial > serial {
|
||||||
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
|
||||||
|
// up-front and only once per target: they are cheap, local, idempotent and must
|
||||||
|
// be in place before peers come up (fail-closed). On the bounded re-runs that only
|
||||||
|
// drain the remaining peer batches they are skipped — the applied forward rules are
|
||||||
|
// reused from e.forwardingRules for the lazy-exclude finalize.
|
||||||
|
if firstPass {
|
||||||
|
e.applyWholesale(networkMap, serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
|
|
||||||
|
// Filter out own peer from the remote peers list
|
||||||
|
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
|
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||||
|
for _, p := range networkMap.GetRemotePeers() {
|
||||||
|
if p.GetWgPubKey() != localPubKey {
|
||||||
|
remotePeers = append(remotePeers, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
|
||||||
|
needMore := false
|
||||||
|
|
||||||
|
// cleanup request, most likely our peer has been deleted
|
||||||
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
|
err := e.removeAllPeers()
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
removeMore, err := e.removePeers(remotePeers, maxBatch)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addMore, err := e.addNewPeers(remotePeers, maxBatch)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
needMore = removeMore || modifyMore || addMore
|
||||||
|
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
|
e.updatePeerSSHHostKeys(remotePeers)
|
||||||
|
|
||||||
|
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the exclude list only once peers have fully converged (this pass added
|
||||||
|
// the last batch). It needs all target peers present in the store, and
|
||||||
|
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
|
||||||
|
if !needMore {
|
||||||
|
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
|
||||||
|
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.networkSerial = serial
|
||||||
|
|
||||||
|
return needMore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
|
||||||
|
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
|
||||||
|
// ingress forward rules — that must be in place before peers come up. It runs once
|
||||||
|
// per target (first pass only); the resulting forward rules are stashed in
|
||||||
|
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
|
||||||
|
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
|
||||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1359,61 +1528,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
}
|
}
|
||||||
|
e.forwardingRules = forwardingRules
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
|
||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
|
||||||
|
|
||||||
// Filter out own peer from the remote peers list
|
|
||||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
|
||||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
|
||||||
for _, p := range networkMap.GetRemotePeers() {
|
|
||||||
if p.GetWgPubKey() != localPubKey {
|
|
||||||
remotePeers = append(remotePeers, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup request, most likely our peer has been deleted
|
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
|
||||||
err := e.removeAllPeers()
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err := e.removePeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.modifyPeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.addNewPeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
|
|
||||||
e.updatePeerSSHHostKeys(remotePeers)
|
|
||||||
|
|
||||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
|
||||||
log.Warnf("failed to update SSH client config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
|
||||||
}
|
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
|
||||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
|
||||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||||
@@ -1593,14 +1708,23 @@ func addrToString(addr netip.Addr) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
|
||||||
|
// when more new peers remained than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
|
added := 0
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
err := e.addNewPeer(p)
|
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
|
||||||
if err != nil {
|
continue // already present (cheap skip), does not count toward the cap
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
if added >= maxBatch {
|
||||||
|
return true, nil // at least one more new peer remains
|
||||||
|
}
|
||||||
|
if err := e.addNewPeer(p); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
added++
|
||||||
}
|
}
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addNewPeer add peer if connection doesn't exist
|
// addNewPeer add peer if connection doesn't exist
|
||||||
@@ -1698,7 +1822,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
}
|
}
|
||||||
|
|
||||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() error {
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
@@ -1769,7 +1893,12 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
e.signal.WaitStreamConnected()
|
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||||
|
e.signal.WaitStreamConnected(e.ctx)
|
||||||
|
if err := e.ctx.Err(); err != nil {
|
||||||
|
return fmt.Errorf("wait for signal stream: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||||
|
|||||||
565
client/internal/engine_privileged_test.go
Normal file
565
client/internal/engine_privileged_test.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEngine_SSH(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(
|
||||||
|
ctx, cancel,
|
||||||
|
&EngineConfig{
|
||||||
|
WgIfaceName: "utun101",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
ServerSSHAllowed: true,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
SSHKey: sshKey,
|
||||||
|
},
|
||||||
|
EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
},
|
||||||
|
MobileDependency{},
|
||||||
|
)
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.21/24"},
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||||
|
networkMap := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 6,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// SSH server is enabled, therefore SSH config should be applied
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 7,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now remove peer
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 8,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now disable SSH server
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 9,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Sync(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// feed updates to Engine via mocked Management client
|
||||||
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
|
defer close(updates)
|
||||||
|
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
|
for msg := range updates {
|
||||||
|
err := msgHandler(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
|
WgIfaceName: "utun103",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.10/24"},
|
||||||
|
}
|
||||||
|
peer2 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.11/24"},
|
||||||
|
}
|
||||||
|
peer3 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.12/24"},
|
||||||
|
}
|
||||||
|
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||||
|
updates <- &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.After(time.Second * 2)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("timeout while waiting for test to finish")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_MultiplePeers(t *testing.T) {
|
||||||
|
// log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigServer, signalAddr, err := startSignal(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sigServer.Stop()
|
||||||
|
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
|
mu := sync.Mutex{}
|
||||||
|
engines := []*Engine{}
|
||||||
|
numPeers := 10
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(numPeers)
|
||||||
|
// create and start peers
|
||||||
|
for i := 0; i < numPeers; i++ {
|
||||||
|
j := i
|
||||||
|
go func() {
|
||||||
|
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||||
|
if err != nil {
|
||||||
|
wg.Done()
|
||||||
|
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engine.dnsServer = &dns.MockServer{}
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engines = append(engines, engine)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until all have been created and started
|
||||||
|
wg.Wait()
|
||||||
|
if len(engines) != numPeers {
|
||||||
|
t.Fatal("not all peers were started")
|
||||||
|
}
|
||||||
|
// check whether all the peer have expected peers connected
|
||||||
|
|
||||||
|
expectedConnected := numPeers * (numPeers - 1)
|
||||||
|
|
||||||
|
// adjust according to timeouts
|
||||||
|
timeout := 50 * time.Second
|
||||||
|
timeoutChan := time.After(timeout)
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||||
|
break loop
|
||||||
|
case <-ticker.C:
|
||||||
|
totalConnected := 0
|
||||||
|
for _, engine := range engines {
|
||||||
|
totalConnected += getConnectedPeers(engine)
|
||||||
|
}
|
||||||
|
if totalConnected == expectedConnected {
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cleanup test
|
||||||
|
for n, peerEngine := range engines {
|
||||||
|
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||||
|
errStop := peerEngine.mgmClient.Close()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||||
|
}
|
||||||
|
errStop = peerEngine.Stop()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
kaep = keepalive.EnforcementPolicy{
|
||||||
|
MinTime: 15 * time.Second,
|
||||||
|
PermitWithoutStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kasp = keepalive.ServerParameters{
|
||||||
|
MaxConnectionIdle: 15 * time.Second,
|
||||||
|
MaxConnectionAgeGrace: 5 * time.Second,
|
||||||
|
Time: 5 * time.Second,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ifaceName string
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||||
|
} else {
|
||||||
|
ifaceName = fmt.Sprintf("wt%d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgPort := 33100 + i
|
||||||
|
conf := &EngineConfig{
|
||||||
|
WgIfaceName: ifaceName,
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: wgPort,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmtClient,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{}), nil
|
||||||
|
e.ctx = ctx
|
||||||
|
return e, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
config := &config.Config{
|
||||||
|
Stuns: []*config.Host{},
|
||||||
|
TURNConfig: &config.TURNConfig{},
|
||||||
|
Relay: &config.Relay{
|
||||||
|
Addresses: []string{"127.0.0.1:1234"},
|
||||||
|
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||||
|
Secret: "222222222222222222",
|
||||||
|
},
|
||||||
|
Signal: &config.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: "localhost:10000",
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.ExtraSettings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
|
func getConnectedPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
i := 0
|
||||||
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
|
if conn.IsConnected() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
return len(e.peerStore.PeersPubKey())
|
||||||
|
}
|
||||||
@@ -6,37 +6,18 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -50,18 +31,7 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
@@ -69,25 +39,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
kaep = keepalive.EnforcementPolicy{
|
|
||||||
MinTime: 15 * time.Second,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
kasp = keepalive.ServerParameters{
|
|
||||||
MaxConnectionIdle: 15 * time.Second,
|
|
||||||
MaxConnectionAgeGrace: 5 * time.Second,
|
|
||||||
Time: 5 * time.Second,
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockWGIface struct {
|
type MockWGIface struct {
|
||||||
CreateFunc func() error
|
CreateFunc func() error
|
||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||||
@@ -234,129 +188,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(
|
|
||||||
ctx, cancel,
|
|
||||||
&EngineConfig{
|
|
||||||
WgIfaceName: "utun101",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
ServerSSHAllowed: true,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
SSHKey: sshKey,
|
|
||||||
},
|
|
||||||
EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
},
|
|
||||||
MobileDependency{},
|
|
||||||
)
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.21/24"},
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
|
||||||
networkMap := &mgmtProto.NetworkMap{
|
|
||||||
Serial: 6,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// SSH server is enabled, therefore SSH config should be applied
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 7,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
JwtConfig: &mgmtProto.JWTConfig{
|
|
||||||
Issuer: "test-issuer",
|
|
||||||
Audience: "test-audience",
|
|
||||||
KeysLocation: "test-keys",
|
|
||||||
MaxTokenAge: 3600,
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now remove peer
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 8,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now disable SSH server
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 9,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||||
// Test that SSH server start/stop logic works based on config
|
// Test that SSH server start/stop logic works based on config
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
@@ -426,7 +257,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
@@ -602,7 +433,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
|
|
||||||
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
err = engine.updateNetworkMap(c.networkMap)
|
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -629,97 +460,47 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_Sync(t *testing.T) {
|
// chunked apply: with a per-pass cap smaller than the number of peers, a
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
// single updateNetworkMap applies one batch and reports more==true; the
|
||||||
if err != nil {
|
// caller re-runs until convergence. (engine currently holds 0 peers.)
|
||||||
t.Fatal(err)
|
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
|
||||||
return
|
nm := &mgmtProto.NetworkMap{
|
||||||
}
|
Serial: 6,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// feed updates to Engine via mocked Management client
|
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
|
||||||
defer close(updates)
|
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
|
||||||
for msg := range updates {
|
|
||||||
err := msgHandler(msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
|
||||||
WgIfaceName: "utun103",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}, EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{})
|
|
||||||
engine.ctx = ctx
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
peer1 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.10/24"},
|
|
||||||
}
|
|
||||||
peer2 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.11/24"},
|
|
||||||
}
|
|
||||||
peer3 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.12/24"},
|
|
||||||
}
|
|
||||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
|
||||||
updates <- &mgmtProto.SyncResponse{
|
|
||||||
NetworkMap: &mgmtProto.NetworkMap{
|
|
||||||
Serial: 10,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := time.After(time.Second * 2)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
t.Fatalf("timeout while waiting for test to finish")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||||
break
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 1 should signal more")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 2 should signal more")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 2)
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, more, "pass 3 should converge")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 3)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
|
||||||
|
nm := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 7,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, more, "pass 2 should converge")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||||
@@ -817,7 +598,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||||
@@ -890,7 +671,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = engine.updateNetworkMap(testCase.networkMap)
|
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||||
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||||
@@ -1024,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||||
@@ -1094,7 +875,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = engine.updateNetworkMap(testCase.networkMap)
|
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||||
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
||||||
@@ -1105,104 +886,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_MultiplePeers(t *testing.T) {
|
|
||||||
// log.SetLevel(log.DebugLevel)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigServer, signalAddr, err := startSignal(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer sigServer.Stop()
|
|
||||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer mgmtServer.GracefulStop()
|
|
||||||
|
|
||||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
|
||||||
|
|
||||||
mu := sync.Mutex{}
|
|
||||||
engines := []*Engine{}
|
|
||||||
numPeers := 10
|
|
||||||
wg := sync.WaitGroup{}
|
|
||||||
wg.Add(numPeers)
|
|
||||||
// create and start peers
|
|
||||||
for i := 0; i < numPeers; i++ {
|
|
||||||
j := i
|
|
||||||
go func() {
|
|
||||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
|
||||||
if err != nil {
|
|
||||||
wg.Done()
|
|
||||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engine.dnsServer = &dns.MockServer{}
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
|
||||||
wg.Done()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engines = append(engines, engine)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until all have been created and started
|
|
||||||
wg.Wait()
|
|
||||||
if len(engines) != numPeers {
|
|
||||||
t.Fatal("not all peers was started")
|
|
||||||
}
|
|
||||||
// check whether all the peer have expected peers connected
|
|
||||||
|
|
||||||
expectedConnected := numPeers * (numPeers - 1)
|
|
||||||
|
|
||||||
// adjust according to timeouts
|
|
||||||
timeout := 50 * time.Second
|
|
||||||
timeoutChan := time.After(timeout)
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
|
||||||
break loop
|
|
||||||
case <-ticker.C:
|
|
||||||
totalConnected := 0
|
|
||||||
for _, engine := range engines {
|
|
||||||
totalConnected += getConnectedPeers(engine)
|
|
||||||
}
|
|
||||||
if totalConnected == expectedConnected {
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// cleanup test
|
|
||||||
for n, peerEngine := range engines {
|
|
||||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
|
||||||
errStop := peerEngine.mgmClient.Close()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
|
||||||
}
|
|
||||||
errStop = peerEngine.Stop()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||||
ifaceList, err := net.Interfaces()
|
ifaceList, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1526,187 +1209,6 @@ func TestCompareNetIPLists(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ifaceName string
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
|
||||||
} else {
|
|
||||||
ifaceName = fmt.Sprintf("wt%d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgPort := 33100 + i
|
|
||||||
conf := &EngineConfig{
|
|
||||||
WgIfaceName: ifaceName,
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: wgPort,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
|
||||||
SignalClient: signalClient,
|
|
||||||
MgmClient: mgmtClient,
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{}), nil
|
|
||||||
e.ctx = ctx
|
|
||||||
return e, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
|
||||||
require.NoError(t, err)
|
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
config := &config.Config{
|
|
||||||
Stuns: []*config.Host{},
|
|
||||||
TURNConfig: &config.TURNConfig{},
|
|
||||||
Relay: &config.Relay{
|
|
||||||
Addresses: []string{"127.0.0.1:1234"},
|
|
||||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
|
||||||
Secret: "222222222222222222",
|
|
||||||
},
|
|
||||||
Signal: &config.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: "localhost:10000",
|
|
||||||
},
|
|
||||||
Datadir: dataDir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.Settings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.ExtraSettings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func getConnectedPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
i := 0
|
|
||||||
for _, id := range e.peerStore.PeersPubKey() {
|
|
||||||
conn, _ := e.peerStore.PeerConn(id)
|
|
||||||
if conn.IsConnected() {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
return len(e.peerStore.PeersPubKey())
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
b, err := netiputil.EncodePrefix(p)
|
b, err := netiputil.EncodePrefix(p)
|
||||||
|
|||||||
@@ -119,10 +119,6 @@ func (d *BindListener) ReadPackets() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
|
||||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = d.lazyConn.Close()
|
_ = d.lazyConn.Close()
|
||||||
d.bind.RemoveEndpoint(d.fakeIP)
|
d.bind.RemoveEndpoint(d.fakeIP)
|
||||||
d.done.Done()
|
d.done.Done()
|
||||||
|
|||||||
190
client/internal/mapsync.go
Normal file
190
client/internal/mapsync.go
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mapStateManager is the single read/write point between the management stream
|
||||||
|
// (writes) and the convergence loop (reads/applies).
|
||||||
|
//
|
||||||
|
// The stream calls SetTarget with the latest full SyncResponse — the complete
|
||||||
|
// desired state. A single background goroutine (run) applies it to the engine in
|
||||||
|
// bounded passes via apply() until converged, releasing syncMsgMux between passes
|
||||||
|
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
|
||||||
|
// coalesces: it keeps converging toward the latest target and the intermediate one
|
||||||
|
// is SKIPPED — never applied on its own (logged, no onConverged).
|
||||||
|
//
|
||||||
|
// Convergence is a single comparison: appliedGen == targetGen. targetGen
|
||||||
|
// increments on every SetTarget (an internal generation counter, so it also covers
|
||||||
|
// config-only updates that carry no network-map serial).
|
||||||
|
//
|
||||||
|
// onConverged fires once for each — and only each — map that is actually processed
|
||||||
|
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
|
||||||
|
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
|
||||||
|
// to a real, completed alignment.
|
||||||
|
type mapStateManager struct {
|
||||||
|
// apply performs one bounded apply pass and reports whether more passes are needed.
|
||||||
|
// firstPass is true on the first pass of a given target, so the caller can run
|
||||||
|
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
|
||||||
|
// re-runs that only drain the bounded peer batches. The manager owns this signal
|
||||||
|
// because it owns the convergence boundary; the engine need not track serials for it.
|
||||||
|
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
|
||||||
|
// onConverged is called once per processed map, with the elapsed time since that
|
||||||
|
// map was received (for the sync-duration metric / "sync finished" log).
|
||||||
|
onConverged func(time.Duration)
|
||||||
|
// persist snapshots an update to disk for restore-on-restart. Called once per
|
||||||
|
// update received from management (in SetTarget), including ones later coalesced
|
||||||
|
// or skipped from apply, so the on-disk state mirrors what management last sent.
|
||||||
|
// The impl skips config-only updates (nil NetworkMap). May be nil.
|
||||||
|
persist func(*mgmProto.SyncResponse)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
target *mgmProto.SyncResponse
|
||||||
|
targetGen uint64
|
||||||
|
appliedGen uint64
|
||||||
|
targetSetAt time.Time
|
||||||
|
|
||||||
|
wake chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
|
||||||
|
return &mapStateManager{
|
||||||
|
apply: apply,
|
||||||
|
persist: persist,
|
||||||
|
onConverged: onConverged,
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTarget records the latest update as the desired state and wakes the loop.
|
||||||
|
// It returns immediately; convergence happens in the background. Serial-based
|
||||||
|
// staleness of the network map is still enforced inside apply (updateNetworkMap).
|
||||||
|
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
// A target that has not settled yet (targetGen > appliedGen) is being superseded
|
||||||
|
// before it converged: we coalesce to the latest map and never apply this one on
|
||||||
|
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
|
||||||
|
if m.target != nil && m.targetGen > m.appliedGen {
|
||||||
|
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
|
||||||
|
}
|
||||||
|
m.target = m.mergeTarget(m.target, update)
|
||||||
|
// Bump an internal generation counter, NOT the map serial: config-only updates
|
||||||
|
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
|
||||||
|
// serial, yet must still be applied. Every SetTarget is therefore a distinct
|
||||||
|
// target regardless of payload. Map-serial staleness is enforced separately
|
||||||
|
// inside apply (updateNetworkMap).
|
||||||
|
m.targetGen++
|
||||||
|
m.targetSetAt = time.Now()
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case m.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist every update received from management — once per update (not per apply
|
||||||
|
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
|
||||||
|
// state always reflects the latest map management sent. Done after waking the loop
|
||||||
|
// so convergence can start in parallel with the disk write. The persist impl skips
|
||||||
|
// config-only updates (nil NetworkMap).
|
||||||
|
if m.persist != nil {
|
||||||
|
m.persist(update)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeTarget combines the currently pending target with a freshly received update
|
||||||
|
// and returns the new desired state. It is called under m.mu from SetTarget and is
|
||||||
|
// the single seam where the replace-vs-squash decision lives.
|
||||||
|
//
|
||||||
|
// Today management always sends a FULL map (the complete desired state), so the
|
||||||
|
// update simply replaces whatever was pending — prev is ignored. When management
|
||||||
|
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
|
||||||
|
// rest of the manager (generation tracking, convergence, signaling) is unaffected
|
||||||
|
// because it already treats target as "the complete desired state, whatever it is".
|
||||||
|
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
|
||||||
|
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
|
||||||
|
func (m *mapStateManager) run(ctx context.Context) {
|
||||||
|
// passGen is the generation of the most recent apply() call (0 = none). A pass is
|
||||||
|
// the first for its target when its generation differs from the previous one —
|
||||||
|
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
|
||||||
|
var passGen uint64
|
||||||
|
for {
|
||||||
|
m.mu.Lock()
|
||||||
|
target, tg, ag := m.target, m.targetGen, m.appliedGen
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Fully converged (or nothing yet): block until a new target arrives.
|
||||||
|
if target == nil || ag == tg {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-m.wake:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
firstPass := tg != passGen
|
||||||
|
passGen = tg
|
||||||
|
more, err := m.apply(target, firstPass)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Log and DROP this target — do not retry it. A deterministic failure
|
||||||
|
// (e.g. a malformed peer in the map) would otherwise spin every pass
|
||||||
|
// making no progress. Management is the source of truth and re-delivers
|
||||||
|
// the full map on the next sync, so dropping is safe; peers already
|
||||||
|
// applied this convergence stay (idempotent diffs) and the remainder is
|
||||||
|
// reconciled by the next target. Mirrors the legacy handleSync path,
|
||||||
|
// where the apply error was logged by the gRPC client and the update
|
||||||
|
// dropped. No onConverged: this target did not converge.
|
||||||
|
log.Errorf("apply sync pass, dropping update: %v", err)
|
||||||
|
m.settle(tg, false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if more {
|
||||||
|
// keep converging the current target; syncMsgMux was released by apply
|
||||||
|
// between passes so other subsystems interleave.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// This pass converged. Mark applied and signal this one map.
|
||||||
|
m.settle(tg, true)
|
||||||
|
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
|
||||||
|
// ag<tg next iteration -> apply it; this generation was skipped (logged in
|
||||||
|
// SetTarget) and is not signaled.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// settle marks generation tg as processed so the loop goes idle instead of
|
||||||
|
// re-applying the same target. It is a no-op when a newer target arrived during the
|
||||||
|
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
|
||||||
|
// just-finished generation was already counted as skipped.
|
||||||
|
//
|
||||||
|
// When signal is true (the pass converged) it fires onConverged once for this map;
|
||||||
|
// when false (the target was dropped on error) it does not — the map did not converge.
|
||||||
|
func (m *mapStateManager) settle(tg uint64, signal bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.targetGen != tg {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.appliedGen = tg
|
||||||
|
setAt := m.targetSetAt
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if signal && m.onConverged != nil {
|
||||||
|
m.onConverged(time.Since(setAt))
|
||||||
|
}
|
||||||
|
}
|
||||||
242
client/internal/mapsync_test.go
Normal file
242
client/internal/mapsync_test.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// converges over the bounded passes (apply returns more until the 3rd pass),
|
||||||
|
// fires onConverged exactly once, then blocks (no further apply) until a new target.
|
||||||
|
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
|
||||||
|
var passes int32
|
||||||
|
var firstPasses int32
|
||||||
|
converged := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||||
|
n := atomic.AddInt32(&passes, 1)
|
||||||
|
if firstPass {
|
||||||
|
atomic.AddInt32(&firstPasses, 1)
|
||||||
|
}
|
||||||
|
return n < 3, nil // more on pass 1 and 2, converge on pass 3
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("manager did not converge")
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
|
||||||
|
|
||||||
|
// once converged the loop blocks: no further apply calls
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist runs once per received update (not per apply pass), regardless of how many
|
||||||
|
// bounded passes that target takes to converge.
|
||||||
|
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
|
||||||
|
var passes, persists int32
|
||||||
|
converged := make(chan struct{}, 1)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
n := atomic.AddInt32(&passes, 1)
|
||||||
|
return n < 3, nil // 3 passes for one target
|
||||||
|
}
|
||||||
|
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||||
|
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not converge")
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// every update received from management is persisted — even one that is coalesced /
|
||||||
|
// skipped from apply before it ever converges.
|
||||||
|
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
|
||||||
|
release := make(chan struct{})
|
||||||
|
var persists int32
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
<-release // hold the first apply so the second update coalesces/skips
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||||
|
m := newMapStateManager(apply, persist, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
|
||||||
|
close(release)
|
||||||
|
|
||||||
|
// both updates persisted even though map1 is skipped from apply
|
||||||
|
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// each map that is actually processed (converged before the next arrives) fires
|
||||||
|
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
|
||||||
|
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
|
||||||
|
converged := make(chan struct{}, 8)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
const maps = 3
|
||||||
|
for i := 0; i < maps; i++ {
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select { // wait for this map to converge before sending the next (no coalescing)
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("map %d not signaled", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// no extra signals once the stream goes quiet
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
t.Fatal("unexpected extra onConverged")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// a map superseded before it converges is skipped: only the latest (processed) map
|
||||||
|
// fires onConverged, not the skipped one.
|
||||||
|
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
|
||||||
|
release := make(chan struct{})
|
||||||
|
var applies, converged atomic.Int32
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applies.Add(1)
|
||||||
|
<-release // hold the first apply in-flight so we can queue a newer target
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
// map1 is picked up; its apply blocks on release
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
|
||||||
|
|
||||||
|
// map2 supersedes map1 before it settled -> map1 is skipped
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
close(release) // let both applies proceed
|
||||||
|
|
||||||
|
// only the processed (latest) map signals; the skipped one does not
|
||||||
|
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
|
||||||
|
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// an apply error drops the target: no retry of the same target, no onConverged,
|
||||||
|
// the loop goes idle — and a fresh target is still applied afterwards.
|
||||||
|
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
|
||||||
|
applied := make(chan struct{}, 8)
|
||||||
|
var failNext atomic.Bool
|
||||||
|
failNext.Store(true)
|
||||||
|
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applied <- struct{}{}
|
||||||
|
if failNext.Load() {
|
||||||
|
return false, errors.New("boom")
|
||||||
|
}
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
var converged atomic.Int32
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
// first target errors -> applied once, then dropped (no retry, no onConverged)
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("errored target not applied")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
t.Fatal("errored target must not be retried")
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
|
||||||
|
|
||||||
|
// a new target is still processed normally and converges
|
||||||
|
failNext.Store(false)
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("new target after error not applied")
|
||||||
|
}
|
||||||
|
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a new target after convergence triggers a fresh apply; an idle (converged)
|
||||||
|
// manager does not apply on its own.
|
||||||
|
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
|
||||||
|
applied := make(chan struct{}, 8)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applied <- struct{}{}
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("first target not applied")
|
||||||
|
}
|
||||||
|
|
||||||
|
// converged → must stay idle (no spurious apply)
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
t.Fatal("unexpected apply while idle/converged")
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("new target not applied")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
offer := h.buildOfferAnswer()
|
offer := h.buildOfferAnswer()
|
||||||
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
|
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
|
||||||
|
|
||||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) sendAnswer() error {
|
func (h *Handshaker) sendAnswer() error {
|
||||||
answer := h.buildOfferAnswer()
|
answer := h.buildOfferAnswer()
|
||||||
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
|
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
|
||||||
|
|
||||||
return h.signaler.SignalAnswer(answer, h.config.Key)
|
return h.signaler.SignalAnswer(answer, h.config.Key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
|||||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||||
type Status struct {
|
type Status struct {
|
||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
|
muxRelays sync.RWMutex
|
||||||
peers map[string]State
|
peers map[string]State
|
||||||
ipToKey map[string]string
|
ipToKey map[string]string
|
||||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||||
@@ -244,8 +245,8 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||||
d.mux.Lock()
|
d.muxRelays.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.muxRelays.Unlock()
|
||||||
d.relayMgr = manager
|
d.relayMgr = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -906,8 +907,8 @@ func (d *Status) MarkSignalConnected() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||||
d.mux.Lock()
|
d.muxRelays.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.muxRelays.Unlock()
|
||||||
d.relayStates = relayResults
|
d.relayStates = relayResults
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1018,24 +1019,26 @@ func (d *Status) GetSignalState() SignalState {
|
|||||||
|
|
||||||
// GetRelayStates returns the stun/turn/permanent relay states
|
// GetRelayStates returns the stun/turn/permanent relay states
|
||||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
d.mux.RLock()
|
d.muxRelays.RLock()
|
||||||
defer d.mux.RUnlock()
|
|
||||||
if d.relayMgr == nil {
|
if d.relayMgr == nil {
|
||||||
return d.relayStates
|
defer d.muxRelays.RUnlock()
|
||||||
|
return slices.Clone(d.relayStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayMgr := d.relayMgr
|
||||||
// extend the list of stun, turn servers with the relay server connections
|
// extend the list of stun, turn servers with the relay server connections
|
||||||
relayStates := slices.Clone(d.relayStates)
|
relayStates := slices.Clone(d.relayStates)
|
||||||
|
d.muxRelays.RUnlock()
|
||||||
|
|
||||||
states := d.relayMgr.RelayStates()
|
states := relayMgr.RelayStates()
|
||||||
if len(states) == 0 {
|
if len(states) == 0 {
|
||||||
// no relay connection tracked yet; surface configured servers as
|
// no relay connection tracked yet; surface configured servers as
|
||||||
// unavailable with the real reconnect error when known
|
// unavailable with the real reconnect error when known
|
||||||
err := relayClient.ErrRelayClientNotConnected
|
err := relayClient.ErrRelayClientNotConnected
|
||||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||||
err = connErr
|
err = connErr
|
||||||
}
|
}
|
||||||
for _, r := range d.relayMgr.ServerURLs() {
|
for _, r := range relayMgr.ServerURLs() {
|
||||||
relayStates = append(relayStates, relay.ProbeResult{
|
relayStates = append(relayStates, relay.ProbeResult{
|
||||||
URI: r,
|
URI: r,
|
||||||
Err: err,
|
Err: err,
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
|
||||||
if *input.ServerSSHAllowed {
|
if *input.ServerSSHAllowed {
|
||||||
log.Infof("enabling SSH server")
|
log.Infof("enabling SSH server")
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -242,6 +242,35 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
|
||||||
|
// Configs written before ServerSSHAllowed was introduced lack the field and
|
||||||
|
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
|
||||||
|
// apply the value instead of panicking on a nil pointer dereference.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input *bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"enable", util.True(), true},
|
||||||
|
{"disable", util.False(), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
|
||||||
|
|
||||||
|
config, err := UpdateConfig(ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
ServerSSHAllowed: tt.input,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
|
||||||
|
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateOldManagementURL(t *testing.T) {
|
func TestUpdateOldManagementURL(t *testing.T) {
|
||||||
origProber := newMgmProber
|
origProber := newMgmProber
|
||||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||||
|
|||||||
@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||||
|
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||||
|
// the original client did not request EDNS0.
|
||||||
|
hadEdns := r.IsEdns0() != nil
|
||||||
|
if !hadEdns {
|
||||||
|
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
}
|
||||||
|
|
||||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||||
|
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||||
|
}
|
||||||
|
if !hadEdns {
|
||||||
|
resutil.StripOPT(reply)
|
||||||
|
}
|
||||||
|
|
||||||
resutil.SetMeta(w, "peer", peerKey)
|
resutil.SetMeta(w, "peer", peerKey)
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEntryExists(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||||
|
|
||||||
|
content := []string{
|
||||||
|
"1000 reserved",
|
||||||
|
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||||
|
"9999 other_table",
|
||||||
|
}
|
||||||
|
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||||
|
|
||||||
|
file, err := os.Open(tempFilePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
assert.NoError(t, file.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id int
|
||||||
|
shouldExist bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ExistsWithNetbirdPrefix",
|
||||||
|
id: 7120,
|
||||||
|
shouldExist: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ExistsWithDifferentName",
|
||||||
|
id: 1000,
|
||||||
|
shouldExist: true,
|
||||||
|
err: ErrTableIDExists,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DoesNotExist",
|
||||||
|
id: 1234,
|
||||||
|
shouldExist: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
exists, err := entryExists(file, tc.id)
|
||||||
|
if tc.err != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.shouldExist, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,191 @@
|
|||||||
|
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRoutes(t *testing.T) {
|
||||||
|
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
var intf *net.Interface
|
||||||
|
var nexthop Nexthop
|
||||||
|
|
||||||
|
_, intf = setupDummyInterface(t)
|
||||||
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
|
r := New(nil, nil)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||||
|
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||||
|
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||||
|
require.NoError(t, err, "Failed to create loopback alias")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||||
|
})
|
||||||
|
|
||||||
|
return intf
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||||
|
require.NoError(t, err, "Failed to parse prefix")
|
||||||
|
|
||||||
|
netIntf, err := net.InterfaceByName(intf)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
|
r := New(nil, nil)
|
||||||
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := r.removeFromRouteTable(prefix, nexthop)
|
||||||
|
assert.NoError(t, err, "Failed to remove route from table")
|
||||||
|
})
|
||||||
|
|
||||||
|
return intf
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var originalNexthop net.IP
|
||||||
|
if dstCIDR == "0.0.0.0/0" {
|
||||||
|
var err error
|
||||||
|
originalNexthop, err = fetchOriginalGateway()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||||
|
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if originalNexthop != nil {
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||||
|
assert.NoError(t, err, "Failed to restore original route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||||
|
require.NoError(t, err, "Failed to add route")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove route")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (net.IP, error) {
|
||||||
|
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, fmt.Errorf("gateway not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ParseIP(matches[1]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||||
|
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||||
|
|
||||||
|
tunName := strings.TrimSpace(string(output))
|
||||||
|
|
||||||
|
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||||
|
|
||||||
|
intf, err := net.InterfaceByName(tunName)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||||
|
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||||
|
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||||
|
|
||||||
|
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||||
|
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||||
|
}
|
||||||
@@ -3,79 +3,24 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files in this package compile.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedVPNint = "utun100"
|
var expectedVPNint = "utun100"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedExternalInt = "lo0"
|
var expectedExternalInt = "lo0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedInternalInt = "lo0"
|
var expectedInternalInt = "lo0"
|
||||||
|
|
||||||
func init() {
|
|
||||||
testCases = append(testCases, []testCase{
|
|
||||||
{
|
|
||||||
name: "To more specific route without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrentRoutes(t *testing.T) {
|
|
||||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
|
||||||
|
|
||||||
var intf *net.Interface
|
|
||||||
var nexthop Nexthop
|
|
||||||
|
|
||||||
_, intf = setupDummyInterface(t)
|
|
||||||
nexthop = Nexthop{netip.Addr{}, intf}
|
|
||||||
|
|
||||||
r := New(nil, nil)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 1024; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(ip netip.Addr) {
|
|
||||||
defer wg.Done()
|
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
|
||||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
|
||||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
|
||||||
}
|
|
||||||
}(baseIP)
|
|
||||||
baseIP = baseIP.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
|
||||||
|
|
||||||
for i := 0; i < 1024; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(ip netip.Addr) {
|
|
||||||
defer wg.Done()
|
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
|
||||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
|
||||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
|
||||||
}
|
|
||||||
}(baseIP)
|
|
||||||
baseIP = baseIP.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -122,122 +67,3 @@ func TestBits(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
|
||||||
require.NoError(t, err, "Failed to create loopback alias")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
|
||||||
})
|
|
||||||
|
|
||||||
return intf
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
|
||||||
require.NoError(t, err, "Failed to parse prefix")
|
|
||||||
|
|
||||||
netIntf, err := net.InterfaceByName(intf)
|
|
||||||
require.NoError(t, err, "Failed to get interface by name")
|
|
||||||
|
|
||||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
|
||||||
|
|
||||||
r := New(nil, nil)
|
|
||||||
err = r.addToRouteTable(prefix, nexthop)
|
|
||||||
require.NoError(t, err, "Failed to add route to table")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := r.removeFromRouteTable(prefix, nexthop)
|
|
||||||
assert.NoError(t, err, "Failed to remove route from table")
|
|
||||||
})
|
|
||||||
|
|
||||||
return intf
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
var originalNexthop net.IP
|
|
||||||
if dstCIDR == "0.0.0.0/0" {
|
|
||||||
var err error
|
|
||||||
originalNexthop, err = fetchOriginalGateway()
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Failed to fetch original gateway: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
|
||||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if originalNexthop != nil {
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
|
||||||
assert.NoError(t, err, "Failed to restore original route")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
|
||||||
require.NoError(t, err, "Failed to add route")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove route")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOriginalGateway() (net.IP, error) {
|
|
||||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return nil, fmt.Errorf("gateway not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.ParseIP(matches[1]), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
|
||||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
|
||||||
|
|
||||||
tunName := strings.TrimSpace(string(output))
|
|
||||||
|
|
||||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
|
||||||
|
|
||||||
intf, err := net.InterfaceByName(tunName)
|
|
||||||
require.NoError(t, err, "Failed to get interface by name")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
|
||||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
|
||||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
|
||||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
type dialer interface {
|
||||||
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android && !ios
|
//go:build !android && !ios && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
@@ -26,11 +26,6 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
|
||||||
Dial(network, address string) (net.Conn, error)
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddVPNRoute(t *testing.T) {
|
func TestAddVPNRoute(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -515,125 +510,3 @@ func setupTestEnv(t *testing.T) {
|
|||||||
// unique route in vpn table
|
// unique route in vpn table
|
||||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsVpnRoute(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
addr string
|
|
||||||
vpnRoutes []string
|
|
||||||
localRoutes []string
|
|
||||||
expectedVpn bool
|
|
||||||
expectedPrefix netip.Prefix
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Match in VPN routes",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Match in local routes",
|
|
||||||
addr: "10.1.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No match",
|
|
||||||
addr: "172.16.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.Prefix{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Default route ignored",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Default route matches but ignored",
|
|
||||||
addr: "172.16.1.1",
|
|
||||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.Prefix{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match local",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16"},
|
|
||||||
localRoutes: []string{"192.168.1.0/24"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match local multiple",
|
|
||||||
addr: "192.168.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
|
||||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match vpn",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"192.168.0.0/16"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match vpn multiple",
|
|
||||||
addr: "192.168.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
|
||||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Duplicate prefix in both",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"192.168.1.0/24"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr, err := netip.ParseAddr(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var vpnRoutes, localRoutes []netip.Prefix
|
|
||||||
for _, route := range tt.vpnRoutes {
|
|
||||||
prefix, err := netip.ParsePrefix(route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
|
||||||
}
|
|
||||||
vpnRoutes = append(vpnRoutes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range tt.localRoutes {
|
|
||||||
prefix, err := netip.ParsePrefix(route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
|
||||||
}
|
|
||||||
localRoutes = append(localRoutes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
|
||||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
|
||||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsVpnRoute(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
vpnRoutes []string
|
||||||
|
localRoutes []string
|
||||||
|
expectedVpn bool
|
||||||
|
expectedPrefix netip.Prefix
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Match in VPN routes",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match in local routes",
|
||||||
|
addr: "10.1.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
addr: "172.16.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.Prefix{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Default route ignored",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Default route matches but ignored",
|
||||||
|
addr: "172.16.1.1",
|
||||||
|
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.Prefix{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match local",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16"},
|
||||||
|
localRoutes: []string{"192.168.1.0/24"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match local multiple",
|
||||||
|
addr: "192.168.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||||
|
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match vpn",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"192.168.0.0/16"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match vpn multiple",
|
||||||
|
addr: "192.168.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||||
|
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate prefix in both",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"192.168.1.0/24"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, err := netip.ParseAddr(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vpnRoutes, localRoutes []netip.Prefix
|
||||||
|
for _, route := range tt.vpnRoutes {
|
||||||
|
prefix, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||||
|
}
|
||||||
|
vpnRoutes = append(vpnRoutes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range tt.localRoutes {
|
||||||
|
prefix, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||||
|
}
|
||||||
|
localRoutes = append(localRoutes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||||
|
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||||
|
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,10 @@
|
|||||||
//go:build !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -18,10 +15,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
)
|
)
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
|
||||||
var expectedExternalInt = "dummyext0"
|
|
||||||
var expectedInternalInt = "dummyint0"
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
testCases = append(testCases, []testCase{
|
testCases = append(testCases, []testCase{
|
||||||
{
|
{
|
||||||
@@ -33,62 +26,6 @@ func init() {
|
|||||||
}...)
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEntryExists(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
|
||||||
|
|
||||||
content := []string{
|
|
||||||
"1000 reserved",
|
|
||||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
|
||||||
"9999 other_table",
|
|
||||||
}
|
|
||||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
|
||||||
|
|
||||||
file, err := os.Open(tempFilePath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
assert.NoError(t, file.Close())
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id int
|
|
||||||
shouldExist bool
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ExistsWithNetbirdPrefix",
|
|
||||||
id: 7120,
|
|
||||||
shouldExist: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ExistsWithDifferentName",
|
|
||||||
id: 1000,
|
|
||||||
shouldExist: true,
|
|
||||||
err: ErrTableIDExists,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DoesNotExist",
|
|
||||||
id: 1234,
|
|
||||||
shouldExist: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
exists, err := entryExists(file, tc.id)
|
|
||||||
if tc.err != nil {
|
|
||||||
assert.ErrorIs(t, err, tc.err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.shouldExist, exists)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files in this package compile.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedVPNint = "wgtest0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedExternalInt = "dummyext0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedInternalInt = "dummyint0"
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||||
|
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||||
|
// BSD/darwin test files compile without the privileged build tag.
|
||||||
|
|
||||||
|
type PacketExpectation struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
SrcPort int
|
||||||
|
DstPort int
|
||||||
|
UDP bool
|
||||||
|
TCP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
expectedPacket PacketExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||||
|
return PacketExpectation{
|
||||||
|
SrcIP: net.ParseIP(srcIP),
|
||||||
|
DstIP: net.ParseIP(dstIP),
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
UDP: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
@@ -20,63 +20,6 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketExpectation struct {
|
|
||||||
SrcIP net.IP
|
|
||||||
DstIP net.IP
|
|
||||||
SrcPort int
|
|
||||||
DstPort int
|
|
||||||
UDP bool
|
|
||||||
TCP bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type testCase struct {
|
|
||||||
name string
|
|
||||||
expectedInterface string
|
|
||||||
dialer dialer
|
|
||||||
expectedPacket PacketExpectation
|
|
||||||
}
|
|
||||||
|
|
||||||
var testCases = []testCase{
|
|
||||||
{
|
|
||||||
name: "To external host without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To external host with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To unique vpn route with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To unique vpn route without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
nbnet.Init()
|
nbnet.Init()
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -102,16 +45,6 @@ func TestRouting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
|
||||||
return PacketExpectation{
|
|
||||||
SrcIP: net.ParseIP(srcIP),
|
|
||||||
DstIP: net.ParseIP(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
UDP: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build windows && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||||
// without v6 connectivity. If a default already exists it is left alone.
|
// without v6 connectivity. If a default already exists it is left alone.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||||
|
|
||||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||||
// without v6 connectivity. If a default already exists it is left alone.
|
// without v6 connectivity. If a default already exists it is left alone.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ type URLOpener interface {
|
|||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
cfgPath string
|
cfgPath string
|
||||||
}
|
}
|
||||||
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||||
|
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||||
|
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||||
|
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||||
|
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||||
|
// process (decoupled from the network extension), so without this the server
|
||||||
|
// lingers after the user dismisses the browser and the next connect stalls
|
||||||
|
// trying to bind the same port.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
return &Auth{
|
return &Auth{
|
||||||
ctx: context.Background(),
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
cfgPath: cfgPath,
|
cfgPath: cfgPath,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
|||||||
|
|
||||||
// NewAuthWithConfig instantiate Auth based on existing config
|
// NewAuthWithConfig instantiate Auth based on existing config
|
||||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
return &Auth{
|
return &Auth{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||||
|
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||||
|
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||||
|
// safe to call when no login is running.
|
||||||
|
func (a *Auth) Stop() {
|
||||||
|
if a.cancel != nil {
|
||||||
|
a.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
|
|||||||
@@ -993,6 +993,10 @@ func (s *Server) cleanupConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||||
|
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||||
|
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||||
|
// making the run loop the sole owner of engine shutdown.
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
if err := engine.Stop(); err != nil {
|
if err := engine.Stop(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
235
client/server/server_privileged_test.go
Normal file
235
client/server/server_privileged_test.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kaep = keepalive.EnforcementPolicy{
|
||||||
|
MinTime: 15 * time.Second,
|
||||||
|
PermitWithoutStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kasp = keepalive.ServerParameters{
|
||||||
|
MaxConnectionIdle: 15 * time.Second,
|
||||||
|
MaxConnectionAgeGrace: 5 * time.Second,
|
||||||
|
Time: 5 * time.Second,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||||
|
// we will use a management server started via to simulate the server and capture the number of retries
|
||||||
|
func TestConnectWithRetryRuns(t *testing.T) {
|
||||||
|
// start the signal server
|
||||||
|
_, signalAddr, err := startSignal(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start signal server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
counter := 0
|
||||||
|
// start the management server
|
||||||
|
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start management server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||||
|
defer cancel()
|
||||||
|
// create new server
|
||||||
|
ic := profilemanager.ConfigInput{
|
||||||
|
ManagementURL: "http://" + mgmtAddr,
|
||||||
|
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pm := profilemanager.ServiceManager{}
|
||||||
|
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||||
|
ID: "test-profile",
|
||||||
|
Username: currUser.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(ctx, "debug", "", false, false, false, false)
|
||||||
|
|
||||||
|
s.config = config
|
||||||
|
|
||||||
|
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||||
|
t.Setenv(retryInitialIntervalVar, "1s")
|
||||||
|
t.Setenv(maxRetryIntervalVar, "2s")
|
||||||
|
t.Setenv(maxRetryTimeVar, "5s")
|
||||||
|
t.Setenv(retryMultiplierVar, "1")
|
||||||
|
|
||||||
|
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||||
|
if counter < 3 {
|
||||||
|
t.Fatalf("expected counter > 2, got %d", counter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockServer struct {
|
||||||
|
mgmtProto.ManagementServiceServer
|
||||||
|
counter *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||||
|
*m.counter++
|
||||||
|
return m.ManagementServiceServer.Login(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
dataDir := t.TempDir()
|
||||||
|
|
||||||
|
config := &config.Config{
|
||||||
|
Stuns: []*config.Host{},
|
||||||
|
TURNConfig: &config.TURNConfig{},
|
||||||
|
Signal: &config.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: signalAddr,
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||||
|
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mock := &mockServer{
|
||||||
|
ManagementServiceServer: mgmtServer,
|
||||||
|
counter: counter,
|
||||||
|
}
|
||||||
|
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
@@ -2,124 +2,22 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
kaep = keepalive.EnforcementPolicy{
|
|
||||||
MinTime: 15 * time.Second,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
kasp = keepalive.ServerParameters{
|
|
||||||
MaxConnectionIdle: 15 * time.Second,
|
|
||||||
MaxConnectionAgeGrace: 5 * time.Second,
|
|
||||||
Time: 5 * time.Second,
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
|
||||||
// we will use a management server started via to simulate the server and capture the number of retries
|
|
||||||
func TestConnectWithRetryRuns(t *testing.T) {
|
|
||||||
// start the signal server
|
|
||||||
_, signalAddr, err := startSignal(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to start signal server: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
counter := 0
|
|
||||||
// start the management server
|
|
||||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to start management server: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
|
||||||
|
|
||||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
|
||||||
defer cancel()
|
|
||||||
// create new server
|
|
||||||
ic := profilemanager.ConfigInput{
|
|
||||||
ManagementURL: "http://" + mgmtAddr,
|
|
||||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
currUser, err := user.Current()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
pm := profilemanager.ServiceManager{}
|
|
||||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
|
||||||
ID: "test-profile",
|
|
||||||
Username: currUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to set active profile state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s := New(ctx, "debug", "", false, false, false, false)
|
|
||||||
|
|
||||||
s.config = config
|
|
||||||
|
|
||||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
|
||||||
t.Setenv(retryInitialIntervalVar, "1s")
|
|
||||||
t.Setenv(maxRetryIntervalVar, "2s")
|
|
||||||
t.Setenv(maxRetryTimeVar, "5s")
|
|
||||||
t.Setenv(retryMultiplierVar, "1")
|
|
||||||
|
|
||||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
|
||||||
if counter < 3 {
|
|
||||||
t.Fatalf("expected counter > 2, got %d", counter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_Up(t *testing.T) {
|
func TestServer_Up(t *testing.T) {
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||||
@@ -259,119 +157,3 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockServer struct {
|
|
||||||
mgmtProto.ManagementServiceServer
|
|
||||||
counter *int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
|
||||||
*m.counter++
|
|
||||||
return m.ManagementServiceServer.Login(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
dataDir := t.TempDir()
|
|
||||||
|
|
||||||
config := &config.Config{
|
|
||||||
Stuns: []*config.Host{},
|
|
||||||
TURNConfig: &config.TURNConfig{},
|
|
||||||
Signal: &config.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: signalAddr,
|
|
||||||
},
|
|
||||||
Datadir: dataDir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
|
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
|
||||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
|
||||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mock := &mockServer{
|
|
||||||
ManagementServiceServer: mgmtServer,
|
|
||||||
counter: counter,
|
|
||||||
}
|
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
|
||||||
require.NoError(t, err)
|
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|||||||
118
client/ssh/client/client_privileged_test.go
Normal file
118
client/ssh/client/client_privileged_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
|
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, _, client := setupTestSSHServerAndClient(t)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||||
|
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(output), "hello")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||||
|
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("commands with flags work", func(t *testing.T) {
|
||||||
|
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||||
|
var testCmd string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
testCmd = "echo hello | Select-String notfound"
|
||||||
|
} else {
|
||||||
|
testCmd = "echo 'hello' | grep 'notfound'"
|
||||||
|
}
|
||||||
|
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||||
|
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("connection with short timeout", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// Check for actual timeout-related errors rather than string matching
|
||||||
|
assert.True(t,
|
||||||
|
errors.Is(err, context.DeadlineExceeded) ||
|
||||||
|
errors.Is(err, context.Canceled) ||
|
||||||
|
strings.Contains(err.Error(), "timeout"),
|
||||||
|
"Expected timeout-related error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command execution cancellation", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
t.Logf("client close error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cmdCancel()
|
||||||
|
|
||||||
|
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||||
|
if err != nil {
|
||||||
|
var exitMissingErr *cryptossh.ExitMissingError
|
||||||
|
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||||
|
errors.Is(err, context.Canceled) ||
|
||||||
|
errors.As(err, &exitMissingErr)
|
||||||
|
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
@@ -78,53 +77,6 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
|||||||
assert.NotNil(t, client.client)
|
assert.NotNil(t, client.client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
|
||||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
|
||||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
|
||||||
}
|
|
||||||
|
|
||||||
server, _, client := setupTestSSHServerAndClient(t)
|
|
||||||
defer func() {
|
|
||||||
err := server.Stop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
err := client.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
|
||||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(output), "hello")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
|
||||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("commands with flags work", func(t *testing.T) {
|
|
||||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
|
||||||
var testCmd string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
testCmd = "echo hello | Select-String notfound"
|
|
||||||
} else {
|
|
||||||
testCmd = "echo 'hello' | grep 'notfound'"
|
|
||||||
}
|
|
||||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -154,59 +106,6 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
|
||||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
|
||||||
defer func() {
|
|
||||||
err := server.Stop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Run("connection with short timeout", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
currentUser := testutil.GetTestUsername(t)
|
|
||||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// Check for actual timeout-related errors rather than string matching
|
|
||||||
assert.True(t,
|
|
||||||
errors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
errors.Is(err, context.Canceled) ||
|
|
||||||
strings.Contains(err.Error(), "timeout"),
|
|
||||||
"Expected timeout-related error, got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("command execution cancellation", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
currentUser := testutil.GetTestUsername(t)
|
|
||||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
if err := client.Close(); err != nil {
|
|
||||||
t.Logf("client close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cmdCancel()
|
|
||||||
|
|
||||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
|
||||||
if err != nil {
|
|
||||||
var exitMissingErr *cryptossh.ExitMissingError
|
|
||||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
errors.Is(err, context.Canceled) ||
|
|
||||||
errors.As(err, &exitMissingErr)
|
|
||||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *mockDaemon) setJWTToken(token string) {
|
||||||
|
m.impl.jwtToken = token
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHProxy_Connect(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &server.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: &server.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audiences: []string{audience},
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer := server.New(serverConfig)
|
||||||
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
// Configure SSH authorization for the test user
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
|
defer func() { _ = sshServer.Stop() }()
|
||||||
|
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
defer mockDaemon.stop()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
origStdin := os.Stdin
|
||||||
|
origStdout := os.Stdout
|
||||||
|
defer func() {
|
||||||
|
os.Stdin = origStdin
|
||||||
|
os.Stdout = origStdout
|
||||||
|
}()
|
||||||
|
|
||||||
|
stdinReader, stdinWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
os.Stdin = stdinReader
|
||||||
|
os.Stdout = stdoutWriter
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
connectErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
connectErrCh <- proxyInstance.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sshConfig := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||||
|
require.NoError(t, err, "Should connect to proxy server")
|
||||||
|
defer func() { _ = sshClientConn.Close() }()
|
||||||
|
|
||||||
|
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||||
|
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||||
|
|
||||||
|
outputCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
output, err := session.Output("echo hello-from-proxy")
|
||||||
|
outputCh <- output
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case output := <-outputCh:
|
||||||
|
err := <-errCh
|
||||||
|
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||||
|
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("Command execution timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = session.Close()
|
||||||
|
_ = sshClient.Close()
|
||||||
|
_ = clientConn.Close()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||||
|
// when forwarding commands to the backend. This is critical for tools like
|
||||||
|
// Ansible that send commands such as:
|
||||||
|
//
|
||||||
|
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||||
|
//
|
||||||
|
// The single quotes must be preserved so the backend shell receives the
|
||||||
|
// subshell expression as a single argument to -c.
|
||||||
|
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClient, cleanup := setupProxySSHClient(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||||
|
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||||
|
// the local shell strips the outer single quotes, and the SSH exec request
|
||||||
|
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||||
|
//
|
||||||
|
// The proxy must forward this string verbatim. Using session.Command()
|
||||||
|
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||||
|
// the command on the backend.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "subshell_in_double_quotes",
|
||||||
|
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||||
|
expect: "from-subshell\nouter\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "printf_with_special_chars",
|
||||||
|
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||||
|
expect: "hello world\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested_command_substitution",
|
||||||
|
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||||
|
expect: "nested\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = session.Close() }()
|
||||||
|
|
||||||
|
var stderrBuf bytes.Buffer
|
||||||
|
session.Stderr = &stderrBuf
|
||||||
|
|
||||||
|
outputCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
output, err := session.Output(tc.command)
|
||||||
|
outputCh <- output
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case output := <-outputCh:
|
||||||
|
err := <-errCh
|
||||||
|
if stderrBuf.Len() > 0 {
|
||||||
|
t.Logf("stderr: %s", stderrBuf.String())
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||||
|
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatalf("command timed out: %s", tc.command)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupProxySSHClient creates a full proxy test environment and returns
|
||||||
|
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||||
|
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &server.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: &server.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audiences: []string{audience},
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer := server.New(serverConfig)
|
||||||
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
|
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
origStdin := os.Stdin
|
||||||
|
origStdout := os.Stdout
|
||||||
|
|
||||||
|
stdinReader, stdinWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
os.Stdin = stdinReader
|
||||||
|
os.Stdout = stdoutWriter
|
||||||
|
|
||||||
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
|
||||||
|
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||||
|
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = proxyInstance.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sshConfig := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||||
|
|
||||||
|
cleanupFn := func() {
|
||||||
|
_ = client.Close()
|
||||||
|
_ = clientConn.Close()
|
||||||
|
cancel()
|
||||||
|
os.Stdin = origStdin
|
||||||
|
os.Stdout = origStdout
|
||||||
|
_ = sshServer.Stop()
|
||||||
|
mockDaemon.stop()
|
||||||
|
jwksServer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, cleanupFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, jwksJSON := generateTestJWKS(t)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if _, err := w.Write(jwksJSON); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return server, privateKey, server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
n := publicKey.N.Bytes()
|
||||||
|
e := publicKey.E
|
||||||
|
|
||||||
|
jwk := nbjwt.JSONWebKey{
|
||||||
|
Kty: "RSA",
|
||||||
|
Kid: "test-key-id",
|
||||||
|
Use: "sig",
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(n),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := nbjwt.Jwks{
|
||||||
|
Keys: []nbjwt.JSONWebKey{jwk},
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksJSON, err := json.Marshal(jwks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return privateKey, jwksJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||||
|
t.Helper()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": user,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
@@ -1,25 +1,12 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
@@ -28,11 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/server"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
|
||||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -106,331 +89,6 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHProxy_Connect(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
issuer = "https://test-issuer.example.com"
|
|
||||||
audience = "test-audience"
|
|
||||||
)
|
|
||||||
|
|
||||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
||||||
defer jwksServer.Close()
|
|
||||||
|
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverConfig := &server.Config{
|
|
||||||
HostKeyPEM: hostKey,
|
|
||||||
JWT: &server.JWTConfig{
|
|
||||||
Issuer: issuer,
|
|
||||||
Audiences: []string{audience},
|
|
||||||
KeysLocation: jwksURL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer := server.New(serverConfig)
|
|
||||||
sshServer.SetAllowRootLogin(true)
|
|
||||||
|
|
||||||
// Configure SSH authorization for the test user
|
|
||||||
testUsername := testutil.GetTestUsername(t)
|
|
||||||
testJWTUser := "test-username"
|
|
||||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authConfig := &sshauth.Config{
|
|
||||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
||||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
||||||
MachineUsers: map[string][]uint32{
|
|
||||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer.UpdateSSHAuth(authConfig)
|
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
||||||
defer func() { _ = sshServer.Stop() }()
|
|
||||||
|
|
||||||
mockDaemon := startMockDaemon(t)
|
|
||||||
defer mockDaemon.stop()
|
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
||||||
mockDaemon.setJWTToken(validToken)
|
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
|
||||||
defer func() { _ = clientConn.Close() }()
|
|
||||||
|
|
||||||
origStdin := os.Stdin
|
|
||||||
origStdout := os.Stdout
|
|
||||||
defer func() {
|
|
||||||
os.Stdin = origStdin
|
|
||||||
os.Stdout = origStdout
|
|
||||||
}()
|
|
||||||
|
|
||||||
stdinReader, stdinWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
os.Stdin = stdinReader
|
|
||||||
os.Stdout = stdoutWriter
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
connectErrCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
connectErrCh <- proxyInstance.Connect(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
sshConfig := &cryptossh.ClientConfig{
|
|
||||||
User: testutil.GetTestUsername(t),
|
|
||||||
Auth: []cryptossh.AuthMethod{},
|
|
||||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
||||||
require.NoError(t, err, "Should connect to proxy server")
|
|
||||||
defer func() { _ = sshClientConn.Close() }()
|
|
||||||
|
|
||||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
||||||
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
|
||||||
|
|
||||||
outputCh := make(chan []byte, 1)
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
output, err := session.Output("echo hello-from-proxy")
|
|
||||||
outputCh <- output
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case output := <-outputCh:
|
|
||||||
err := <-errCh
|
|
||||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
|
||||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("Command execution timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = session.Close()
|
|
||||||
_ = sshClient.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
|
||||||
// when forwarding commands to the backend. This is critical for tools like
|
|
||||||
// Ansible that send commands such as:
|
|
||||||
//
|
|
||||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
|
||||||
//
|
|
||||||
// The single quotes must be preserved so the backend shell receives the
|
|
||||||
// subshell expression as a single argument to -c.
|
|
||||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClient, cleanup := setupProxySSHClient(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
|
||||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
|
||||||
// the local shell strips the outer single quotes, and the SSH exec request
|
|
||||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
|
||||||
//
|
|
||||||
// The proxy must forward this string verbatim. Using session.Command()
|
|
||||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
|
||||||
// the command on the backend.
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
command string
|
|
||||||
expect string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "subshell_in_double_quotes",
|
|
||||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
|
||||||
expect: "from-subshell\nouter\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "printf_with_special_chars",
|
|
||||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
|
||||||
expect: "hello world\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nested_command_substitution",
|
|
||||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
|
||||||
expect: "nested\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = session.Close() }()
|
|
||||||
|
|
||||||
var stderrBuf bytes.Buffer
|
|
||||||
session.Stderr = &stderrBuf
|
|
||||||
|
|
||||||
outputCh := make(chan []byte, 1)
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
output, err := session.Output(tc.command)
|
|
||||||
outputCh <- output
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case output := <-outputCh:
|
|
||||||
err := <-errCh
|
|
||||||
if stderrBuf.Len() > 0 {
|
|
||||||
t.Logf("stderr: %s", stderrBuf.String())
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
|
||||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatalf("command timed out: %s", tc.command)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupProxySSHClient creates a full proxy test environment and returns
|
|
||||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
|
||||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
const (
|
|
||||||
issuer = "https://test-issuer.example.com"
|
|
||||||
audience = "test-audience"
|
|
||||||
)
|
|
||||||
|
|
||||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
||||||
|
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverConfig := &server.Config{
|
|
||||||
HostKeyPEM: hostKey,
|
|
||||||
JWT: &server.JWTConfig{
|
|
||||||
Issuer: issuer,
|
|
||||||
Audiences: []string{audience},
|
|
||||||
KeysLocation: jwksURL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer := server.New(serverConfig)
|
|
||||||
sshServer.SetAllowRootLogin(true)
|
|
||||||
|
|
||||||
testUsername := testutil.GetTestUsername(t)
|
|
||||||
testJWTUser := "test-username"
|
|
||||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authConfig := &sshauth.Config{
|
|
||||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
||||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
||||||
MachineUsers: map[string][]uint32{
|
|
||||||
testUsername: {0},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer.UpdateSSHAuth(authConfig)
|
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
||||||
|
|
||||||
mockDaemon := startMockDaemon(t)
|
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
||||||
mockDaemon.setJWTToken(validToken)
|
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
origStdin := os.Stdin
|
|
||||||
origStdout := os.Stdout
|
|
||||||
|
|
||||||
stdinReader, stdinWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
os.Stdin = stdinReader
|
|
||||||
os.Stdout = stdoutWriter
|
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
|
||||||
|
|
||||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
|
||||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = proxyInstance.Connect(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
sshConfig := &cryptossh.ClientConfig{
|
|
||||||
User: testutil.GetTestUsername(t),
|
|
||||||
Auth: []cryptossh.AuthMethod{},
|
|
||||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
||||||
|
|
||||||
cleanupFn := func() {
|
|
||||||
_ = client.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
cancel()
|
|
||||||
os.Stdin = origStdin
|
|
||||||
os.Stdout = origStdout
|
|
||||||
_ = sshServer.Stop()
|
|
||||||
mockDaemon.stop()
|
|
||||||
jwksServer.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, cleanupFn
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockDaemonServer struct {
|
type mockDaemonServer struct {
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
hostKeys map[string][]byte
|
hostKeys map[string][]byte
|
||||||
@@ -492,10 +150,6 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
|||||||
m.impl.hostKeys[addr] = pubKey
|
m.impl.hostKeys[addr] = pubKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockDaemon) setJWTToken(token string) {
|
|
||||||
m.impl.jwtToken = token
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDaemon) stop() {
|
func (m *mockDaemon) stop() {
|
||||||
if m.server != nil {
|
if m.server != nil {
|
||||||
m.server.Stop()
|
m.server.Stop()
|
||||||
@@ -508,63 +162,3 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return pubKey
|
return pubKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
|
||||||
t.Helper()
|
|
||||||
privateKey, jwksJSON := generateTestJWKS(t)
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if _, err := w.Write(jwksJSON); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
|
|
||||||
return server, privateKey, server.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
|
||||||
t.Helper()
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
publicKey := &privateKey.PublicKey
|
|
||||||
n := publicKey.N.Bytes()
|
|
||||||
e := publicKey.E
|
|
||||||
|
|
||||||
jwk := nbjwt.JSONWebKey{
|
|
||||||
Kty: "RSA",
|
|
||||||
Kid: "test-key-id",
|
|
||||||
Use: "sig",
|
|
||||||
N: base64.RawURLEncoding.EncodeToString(n),
|
|
||||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
|
||||||
}
|
|
||||||
|
|
||||||
jwks := nbjwt.Jwks{
|
|
||||||
Keys: []nbjwt.JSONWebKey{jwk},
|
|
||||||
}
|
|
||||||
|
|
||||||
jwksJSON, err := json.Marshal(jwks)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return privateKey, jwksJSON
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
|
||||||
t.Helper()
|
|
||||||
claims := jwt.MapClaims{
|
|
||||||
"iss": issuer,
|
|
||||||
"aud": audience,
|
|
||||||
"sub": user,
|
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
|
||||||
"iat": time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
||||||
token.Header["kid"] = "test-key-id"
|
|
||||||
|
|
||||||
tokenString, err := token.SignedString(privateKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return tokenString
|
|
||||||
}
|
|
||||||
|
|||||||
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build unix && privileged
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||||
|
pd := NewPrivilegeDropper()
|
||||||
|
|
||||||
|
config := ExecutorConfig{
|
||||||
|
UID: 1000,
|
||||||
|
GID: 1000,
|
||||||
|
Groups: []uint32{1000, 1001},
|
||||||
|
WorkingDir: "/home/testuser",
|
||||||
|
Shell: "/bin/bash",
|
||||||
|
Command: "ls -la",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, cmd)
|
||||||
|
|
||||||
|
// Verify the command is calling netbird ssh exec
|
||||||
|
assert.Contains(t, cmd.Args, "ssh")
|
||||||
|
assert.Contains(t, cmd.Args, "exec")
|
||||||
|
assert.Contains(t, cmd.Args, "--uid")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "--gid")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "--groups")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "1001")
|
||||||
|
assert.Contains(t, cmd.Args, "--working-dir")
|
||||||
|
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||||
|
assert.Contains(t, cmd.Args, "--shell")
|
||||||
|
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||||
|
assert.Contains(t, cmd.Args, "--cmd")
|
||||||
|
assert.Contains(t, cmd.Args, "ls -la")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||||
|
pd := NewPrivilegeDropper()
|
||||||
|
|
||||||
|
config := ExecutorConfig{
|
||||||
|
UID: 1000,
|
||||||
|
GID: 1000,
|
||||||
|
Groups: []uint32{1000},
|
||||||
|
WorkingDir: "/home/testuser",
|
||||||
|
Shell: "/bin/bash",
|
||||||
|
Command: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, cmd)
|
||||||
|
|
||||||
|
// Verify no command mode (command is empty so no --cmd flag)
|
||||||
|
assert.NotContains(t, cmd.Args, "--cmd")
|
||||||
|
assert.NotContains(t, cmd.Args, "--interactive")
|
||||||
|
}
|
||||||
@@ -73,61 +73,6 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
|
||||||
pd := NewPrivilegeDropper()
|
|
||||||
|
|
||||||
config := ExecutorConfig{
|
|
||||||
UID: 1000,
|
|
||||||
GID: 1000,
|
|
||||||
Groups: []uint32{1000, 1001},
|
|
||||||
WorkingDir: "/home/testuser",
|
|
||||||
Shell: "/bin/bash",
|
|
||||||
Command: "ls -la",
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, cmd)
|
|
||||||
|
|
||||||
// Verify the command is calling netbird ssh exec
|
|
||||||
assert.Contains(t, cmd.Args, "ssh")
|
|
||||||
assert.Contains(t, cmd.Args, "exec")
|
|
||||||
assert.Contains(t, cmd.Args, "--uid")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "--gid")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "--groups")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "1001")
|
|
||||||
assert.Contains(t, cmd.Args, "--working-dir")
|
|
||||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
|
||||||
assert.Contains(t, cmd.Args, "--shell")
|
|
||||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
|
||||||
assert.Contains(t, cmd.Args, "--cmd")
|
|
||||||
assert.Contains(t, cmd.Args, "ls -la")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
|
||||||
pd := NewPrivilegeDropper()
|
|
||||||
|
|
||||||
config := ExecutorConfig{
|
|
||||||
UID: 1000,
|
|
||||||
GID: 1000,
|
|
||||||
Groups: []uint32{1000},
|
|
||||||
WorkingDir: "/home/testuser",
|
|
||||||
Shell: "/bin/bash",
|
|
||||||
Command: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, cmd)
|
|
||||||
|
|
||||||
// Verify no command mode (command is empty so no --cmd flag)
|
|
||||||
assert.NotContains(t, cmd.Args, "--cmd")
|
|
||||||
assert.NotContains(t, cmd.Args, "--interactive")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||||
// This test requires root privileges and will be skipped if not running as root
|
// This test requires root privileges and will be skipped if not running as root
|
||||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package system
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -121,6 +122,23 @@ func (i *Info) SetFlags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeAddresses drops network addresses whose IP matches any of the given
|
||||||
|
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
|
||||||
|
// address, which otherwise churns the meta as the interface comes and goes.
|
||||||
|
func (i *Info) removeAddresses(ips ...netip.Addr) {
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filtered := i.NetworkAddresses[:0]
|
||||||
|
for _, addr := range i.NetworkAddresses {
|
||||||
|
if slices.Contains(ips, addr.NetIP.Addr()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, addr)
|
||||||
|
}
|
||||||
|
i.NetworkAddresses = filtered
|
||||||
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
@@ -147,7 +165,9 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
// excludeIPs are dropped from the reported network addresses (e.g. our own
|
||||||
|
// WireGuard overlay address, which otherwise churns the peer meta).
|
||||||
|
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
|
||||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||||
processCheckPaths := make([]string, 0)
|
processCheckPaths := make([]string, 0)
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
@@ -162,6 +182,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
|||||||
|
|
||||||
info := GetInfo(ctx)
|
info := GetInfo(ctx)
|
||||||
info.Files = files
|
info.Files = files
|
||||||
|
info.removeAddresses(excludeIPs...)
|
||||||
|
|
||||||
log.Debugf("all system information gathered successfully")
|
log.Debugf("all system information gathered successfully")
|
||||||
return info, nil
|
return info, nil
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -43,3 +44,42 @@ func Test_NetAddresses(t *testing.T) {
|
|||||||
t.Errorf("no network addresses found")
|
t.Errorf("no network addresses found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInfo_RemoveAddresses(t *testing.T) {
|
||||||
|
addr := func(cidr string) NetworkAddress {
|
||||||
|
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
|
||||||
|
}
|
||||||
|
|
||||||
|
info := &Info{
|
||||||
|
NetworkAddresses: []NetworkAddress{
|
||||||
|
addr("192.168.1.7/24"),
|
||||||
|
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
|
||||||
|
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
|
||||||
|
addr("fd00:1234::1/64"), // overlay v6
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
|
||||||
|
info.removeAddresses(
|
||||||
|
netip.MustParseAddr("100.76.70.97"),
|
||||||
|
netip.MustParseAddr("fd00:1234::1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
|
||||||
|
if len(info.NetworkAddresses) != len(want) {
|
||||||
|
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
|
||||||
|
}
|
||||||
|
for i, w := range want {
|
||||||
|
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
|
||||||
|
t.Errorf("address[%d] = %s, want %s", i, got, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
|
||||||
|
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
|
||||||
|
info.removeAddresses()
|
||||||
|
if len(info.NetworkAddresses) != 1 {
|
||||||
|
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return NetworkAddress{}, false
|
return NetworkAddress{}, false
|
||||||
}
|
}
|
||||||
if ipNet.IP.IsLoopback() {
|
// Skip link-local and multicast: they carry no routable peer info and the
|
||||||
|
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
|
||||||
|
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||||
return NetworkAddress{}, false
|
return NetworkAddress{}, false
|
||||||
}
|
}
|
||||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||||
|
|||||||
45
client/system/network_addr_test.go
Normal file
45
client/system/network_addr_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
|
||||||
|
t.Helper()
|
||||||
|
ip, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse %q: %v", cidr, err)
|
||||||
|
}
|
||||||
|
ipNet.IP = ip
|
||||||
|
return ipNet
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToNetworkAddress_Filtering(t *testing.T) {
|
||||||
|
const mac = "c8:4b:d6:b6:04:ac"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cidr string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"ipv4 global", "10.65.16.181/23", true},
|
||||||
|
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
|
||||||
|
{"ipv4 loopback", "127.0.0.1/8", false},
|
||||||
|
{"ipv6 loopback", "::1/128", false},
|
||||||
|
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
|
||||||
|
{"ipv4 link-local", "169.254.1.2/16", false},
|
||||||
|
{"ipv6 multicast", "ff02::1/128", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
196
client/testutil/privileged/runner_test.go
Normal file
196
client/testutil/privileged/runner_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
//go:build privileged && (linux || darwin)
|
||||||
|
|
||||||
|
// Package privileged provides a self-hosting harness that runs the repo's
|
||||||
|
// privileged-tagged test suite inside a --privileged --cap-add=NET_ADMIN
|
||||||
|
// container, so developers can exercise the root/system-mutating tests on a
|
||||||
|
// non-root host with a single `go test` invocation.
|
||||||
|
package privileged
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/moby/moby/api/types/container"
|
||||||
|
"github.com/ory/dockertest/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// containerImage / containerTag match the image used by the CI privileged job
|
||||||
|
// (.github/workflows/golang-test-linux.yml, test_client_on_docker).
|
||||||
|
const (
|
||||||
|
containerImage = "golang"
|
||||||
|
containerTag = "1.25-alpine"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
containerWorkdir = "/app"
|
||||||
|
containerGoCache = "/root/.cache/go-build"
|
||||||
|
containerGoModCache = "/go/pkg/mod"
|
||||||
|
)
|
||||||
|
|
||||||
|
// alpinePackages are the build/runtime deps the privileged tests need, mirroring
|
||||||
|
// the CI container setup.
|
||||||
|
const alpinePackages = "ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base"
|
||||||
|
|
||||||
|
// privilegedTestPackages is the package list the suite runs, excluding the
|
||||||
|
// server-side trees and UI/upload helpers, matching the CI Docker job's filter.
|
||||||
|
const privilegedTestPackages = `go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server`
|
||||||
|
|
||||||
|
// testWriter forwards container output to the test log line by line.
|
||||||
|
type testWriter struct{ t *testing.T }
|
||||||
|
|
||||||
|
func (w testWriter) Write(p []byte) (int, error) {
|
||||||
|
for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
|
||||||
|
w.t.Log(line)
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRunPrivilegedSuiteInDocker spins up a privileged container, mounts the repo,
|
||||||
|
// and runs `go test -tags 'devcert privileged'` inside it. When already running
|
||||||
|
// inside that container (DOCKER_CI=true) it returns immediately so the real
|
||||||
|
// privileged tests in the suite execute in place instead of recursing.
|
||||||
|
func TestRunPrivilegedSuiteInDocker(t *testing.T) {
|
||||||
|
if os.Getenv("DOCKER_CI") == "true" {
|
||||||
|
t.Skip("inside privileged container, skipping container spawn; privileged tests run in place")
|
||||||
|
}
|
||||||
|
|
||||||
|
repoRoot, err := findRepoRoot()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("locate repo root: %v", err)
|
||||||
|
}
|
||||||
|
goCache, goModCache := hostGoCaches(t)
|
||||||
|
|
||||||
|
// dockertest reads DOCKER_HOST; point it at the active context's socket when
|
||||||
|
// the default one is absent (macOS Docker Desktop, Colima, OrbStack).
|
||||||
|
if host := dockerHost(); host != "" {
|
||||||
|
t.Setenv("DOCKER_HOST", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPoolT registers container cleanup via t.Cleanup automatically.
|
||||||
|
pool := dockertest.NewPoolT(t, "", dockertest.WithMaxWait(30*time.Minute))
|
||||||
|
|
||||||
|
// Keep the container alive so the suite runs via Exec, which yields a clean
|
||||||
|
// exit code (the v4 Resource API exposes no container wait/exit-code).
|
||||||
|
resource := pool.RunT(t, containerImage,
|
||||||
|
dockertest.WithTag(containerTag),
|
||||||
|
dockertest.WithWorkingDir(containerWorkdir),
|
||||||
|
dockertest.WithMounts([]string{
|
||||||
|
repoRoot + ":" + containerWorkdir,
|
||||||
|
goCache + ":" + containerGoCache,
|
||||||
|
goModCache + ":" + containerGoModCache,
|
||||||
|
}),
|
||||||
|
dockertest.WithEnv([]string{
|
||||||
|
"CGO_ENABLED=1",
|
||||||
|
"CI=true",
|
||||||
|
"DOCKER_CI=true",
|
||||||
|
"CONTAINER=true",
|
||||||
|
"GOCACHE=" + containerGoCache,
|
||||||
|
"GOMODCACHE=" + containerGoModCache,
|
||||||
|
}),
|
||||||
|
dockertest.WithCmd([]string{"sleep", "infinity"}),
|
||||||
|
dockertest.WithHostConfig(func(hc *container.HostConfig) {
|
||||||
|
hc.Privileged = true
|
||||||
|
hc.CapAdd = []string{"NET_ADMIN"}
|
||||||
|
}),
|
||||||
|
dockertest.WithoutReuse(),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result, err := resource.Exec(ctx, []string{"sh", "-c", buildTestScript()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run privileged suite in container: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := testWriter{t}
|
||||||
|
_, _ = w.Write([]byte(result.StdOut))
|
||||||
|
_, _ = w.Write([]byte(result.StdErr))
|
||||||
|
|
||||||
|
if result.ExitCode != 0 {
|
||||||
|
t.Fatalf("privileged test suite failed in container (exit code %d)", result.ExitCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRepoRoot walks up from the test's working directory to the module root.
|
||||||
|
func findRepoRoot() (string, error) {
|
||||||
|
dir, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
|
||||||
|
return dir, nil
|
||||||
|
}
|
||||||
|
parent := filepath.Dir(dir)
|
||||||
|
if parent == dir {
|
||||||
|
return "", fmt.Errorf("go.mod not found above %s", dir)
|
||||||
|
}
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dockerHost returns a DOCKER_HOST override when the default socket is missing.
|
||||||
|
// An empty result means the caller should leave DOCKER_HOST untouched (it is
|
||||||
|
// already set, or the default unix socket exists). When neither is present
|
||||||
|
// (common on macOS Docker Desktop, Colima and OrbStack, which use a per-user
|
||||||
|
// socket), it resolves the active docker context's endpoint.
|
||||||
|
func dockerHost() string {
|
||||||
|
if os.Getenv("DOCKER_HOST") != "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, err := os.Stat("/var/run/docker.sock"); err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command("docker", "context", "inspect", "-f", "{{.Endpoints.docker.Host}}").Output()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostGoCaches resolves the host GOCACHE/GOMODCACHE so the container reuses the
|
||||||
|
// existing build/module cache for speed.
|
||||||
|
func hostGoCaches(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
return goEnv(t, "GOCACHE"), goEnv(t, "GOMODCACHE")
|
||||||
|
}
|
||||||
|
|
||||||
|
func goEnv(t *testing.T, key string) string {
|
||||||
|
t.Helper()
|
||||||
|
var out bytes.Buffer
|
||||||
|
cmd := exec.Command("go", "env", key)
|
||||||
|
cmd.Stdout = &out
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
t.Fatalf("go env %s: %v", key, err)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTestScript builds the in-container command. PRIV_PKGS overrides the package
|
||||||
|
// list (default: the full filtered set); PRIV_RUN adds a -run test-name filter.
|
||||||
|
// Both empty reproduces the full privileged suite.
|
||||||
|
func buildTestScript() string {
|
||||||
|
pkgs := privilegedTestPackages + " | xargs"
|
||||||
|
if p := os.Getenv("PRIV_PKGS"); p != "" {
|
||||||
|
pkgs = "echo " + p + " | xargs"
|
||||||
|
}
|
||||||
|
|
||||||
|
runFilter := ""
|
||||||
|
if r := os.Getenv("PRIV_RUN"); r != "" {
|
||||||
|
runFilter = "-run '" + r + "' "
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"apk update >/dev/null && apk add --no-cache %s >/dev/null && %s go test -buildvcs=false -tags 'devcert privileged' %s-v -timeout 20m -p 1",
|
||||||
|
alpinePackages, pkgs, runFilter,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
case args.showProfiles:
|
case args.showProfiles:
|
||||||
s.showProfilesUI()
|
s.showProfilesUI()
|
||||||
case args.showQuickActions:
|
case args.showQuickActions:
|
||||||
s.showQuickActionsUI()
|
// Suppress the on-boot Quick Actions popup when the daemon
|
||||||
|
// reports DisableAutoConnect=true — that flag carries both the
|
||||||
|
// user's "Connect on Startup = off" preference AND any MDM-
|
||||||
|
// enforced override (applyMDMPolicy writes the policy value
|
||||||
|
// into the same Config field). See netbirdio/netbird#5744.
|
||||||
|
if !s.disableAutoConnectFromDaemon() {
|
||||||
|
s.showQuickActionsUI()
|
||||||
|
}
|
||||||
case args.showUpdate:
|
case args.showUpdate:
|
||||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||||
}
|
}
|
||||||
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
|||||||
return features, nil
|
return features, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||||
|
// the active profile has DisableAutoConnect=true. Used by the
|
||||||
|
// --quick-actions startup path to suppress the on-boot popup when the
|
||||||
|
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||||
|
// converge on the same Config field because applyMDMPolicy writes the
|
||||||
|
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||||
|
// daemon hiccup does not silently swallow the popup.
|
||||||
|
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||||
|
activeProf, err := s.profileManager.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
currUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn, err := s.getSrvClient(failFastTimeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||||
|
ProfileName: activeProf.ID.String(),
|
||||||
|
Username: currUser.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return srvCfg.GetDisableAutoConnect()
|
||||||
|
}
|
||||||
|
|
||||||
// getSrvConfig from the service to show it in the settings window.
|
// getSrvConfig from the service to show it in the settings window.
|
||||||
func (s *serviceClient) getSrvConfig() {
|
func (s *serviceClient) getSrvConfig() {
|
||||||
s.managementURL = profilemanager.DefaultManagementURL
|
s.managementURL = profilemanager.DefaultManagementURL
|
||||||
|
|||||||
78
docs/testing-privileged.md
Normal file
78
docs/testing-privileged.md
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# Privileged tests
|
||||||
|
|
||||||
|
Some tests in this repo need `root` or mutate host network state: they create
|
||||||
|
TUN/WireGuard interfaces, open netlink/raw sockets, run eBPF programs, or shell
|
||||||
|
out to `ip`/`iptables`/`nft`/`ifconfig`/`route`. Running them on a developer
|
||||||
|
machine would require `sudo` and could leave stray interfaces or routes behind.
|
||||||
|
|
||||||
|
These tests are gated behind the **`privileged` build tag** so the default test
|
||||||
|
run is host-safe.
|
||||||
|
|
||||||
|
## Running tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Host-safe: excludes privileged tests. Runs as a normal user, no sudo.
|
||||||
|
make test-unit
|
||||||
|
# equivalently:
|
||||||
|
go test -tags devcert ./...
|
||||||
|
|
||||||
|
# Privileged suite: runs the privileged-tagged tests inside a
|
||||||
|
# --privileged --cap-add=NET_ADMIN container (requires Docker).
|
||||||
|
make test-privileged
|
||||||
|
|
||||||
|
# Narrow the container run to a single test / package:
|
||||||
|
PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||||
|
```
|
||||||
|
|
||||||
|
`PRIV_RUN` adds a `-run` test-name filter and `PRIV_PKGS` overrides the package
|
||||||
|
list; both are optional and default to the full privileged suite.
|
||||||
|
|
||||||
|
`make test-privileged` invokes the `ory/dockertest` harness in
|
||||||
|
`client/testutil/privileged/`. The harness:
|
||||||
|
|
||||||
|
1. Skips immediately when it detects it is already inside the container
|
||||||
|
(`DOCKER_CI=true`), so the privileged tests run in place instead of recursing.
|
||||||
|
2. Otherwise spins up a `golang:1.25-alpine` container (matching CI),
|
||||||
|
bind-mounts the repo and the host Go build/module caches, installs the
|
||||||
|
required packages, and runs `go test -tags 'devcert privileged'` over the
|
||||||
|
client packages.
|
||||||
|
3. Streams the container's output to the test log and fails if the suite fails.
|
||||||
|
|
||||||
|
## Adding a privileged test
|
||||||
|
|
||||||
|
A test is privileged if it does any of:
|
||||||
|
|
||||||
|
- creates a real interface via `iface.NewWGIFace(...).Create()`,
|
||||||
|
- opens a netlink or raw socket that hard-fails without `CAP_NET_ADMIN`,
|
||||||
|
- runs an eBPF program (`ebpf.*.Listen()`),
|
||||||
|
- shells out to `ip`, `iptables`, `nft`, `ifconfig`, or `route` to change state.
|
||||||
|
|
||||||
|
Add the tag to the **top** of the file, combined with any existing platform
|
||||||
|
constraint:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//go:build privileged && linux
|
||||||
|
|
||||||
|
package foo
|
||||||
|
```
|
||||||
|
|
||||||
|
If a file mixes privileged and pure-logic tests, **split it**: keep the pure
|
||||||
|
tests (and any shared data — type/var declarations, table-driven `testCases`,
|
||||||
|
helper interfaces) in an untagged file, and move the privileged tests into a
|
||||||
|
`*_privileged_test.go` file with the tag. Shared declarations must stay untagged,
|
||||||
|
otherwise the unprivileged files in the package will not compile.
|
||||||
|
|
||||||
|
Always verify both build modes compile on every target platform:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go vet -tags devcert ./...
|
||||||
|
go vet -tags 'devcert privileged' ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
## CI
|
||||||
|
|
||||||
|
- The `Client / Unit` job runs `go test -tags devcert` with **no** `sudo` — only
|
||||||
|
host-safe tests.
|
||||||
|
- The `Client (Docker) / Unit` job runs `go test -tags 'devcert privileged'`
|
||||||
|
inside a `--privileged --cap-add=NET_ADMIN` container, which is where the
|
||||||
|
privileged tests actually execute.
|
||||||
11
go.mod
11
go.mod
@@ -78,10 +78,12 @@ require (
|
|||||||
github.com/mdp/qrterminal/v3 v3.2.1
|
github.com/mdp/qrterminal/v3 v3.2.1
|
||||||
github.com/miekg/dns v1.1.72
|
github.com/miekg/dns v1.1.72
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
|
github.com/moby/moby/api v1.54.1
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/oapi-codegen/runtime v1.1.2
|
github.com/oapi-codegen/runtime v1.1.2
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
|
github.com/ory/dockertest/v4 v4.0.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
|
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
|
||||||
@@ -145,7 +147,7 @@ require (
|
|||||||
dario.cat/mergo v1.0.1 // indirect
|
dario.cat/mergo v1.0.1 // indirect
|
||||||
filippo.io/edwards25519 v1.1.1 // indirect
|
filippo.io/edwards25519 v1.1.1 // indirect
|
||||||
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
|
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
|
||||||
github.com/Azure/go-ntlmssp v0.1.0 // indirect
|
github.com/Azure/go-ntlmssp v0.1.0 // indirect
|
||||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||||
@@ -177,6 +179,8 @@ require (
|
|||||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
@@ -271,11 +275,12 @@ require (
|
|||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
github.com/moby/moby/client v0.4.0 // indirect
|
||||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||||
github.com/moby/sys/sequential v0.5.0 // indirect
|
github.com/moby/sys/sequential v0.5.0 // indirect
|
||||||
github.com/moby/sys/user v0.3.0 // indirect
|
github.com/moby/sys/user v0.3.0 // indirect
|
||||||
github.com/moby/sys/userns v0.1.0 // indirect
|
github.com/moby/sys/userns v0.1.0 // indirect
|
||||||
github.com/moby/term v0.5.0 // indirect
|
github.com/moby/term v0.5.2 // indirect
|
||||||
github.com/morikuni/aec v1.0.0 // indirect
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
||||||
@@ -341,7 +346,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
28
go.sum
28
go.sum
@@ -23,8 +23,8 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9
|
|||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||||
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
|
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
|
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
|
||||||
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
|
||||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
@@ -117,6 +117,10 @@ github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
|
|||||||
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
|
||||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||||
@@ -480,6 +484,10 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx
|
|||||||
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4=
|
||||||
|
github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs=
|
||||||
|
github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw=
|
||||||
|
github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g=
|
||||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||||
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
|
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
|
||||||
@@ -488,8 +496,8 @@ github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo=
|
|||||||
github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
@@ -510,8 +518,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
|
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a h1:3CWK+yTvRKOcC0Q8VCTGy4l60TEb27CQVS7LkMxwjmw=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||||
@@ -542,6 +550,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
|||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/ory/dockertest/v4 v4.0.0 h1:i19aFsO/VXE0VrMk4ifnKW4G/KIJ93PCjLOslxXoPME=
|
||||||
|
github.com/ory/dockertest/v4 v4.0.0/go.mod h1:b5Ofu8VIxWNhXFvQcLu17pRNQdoUBKtXBW74G4Ygzx8=
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
@@ -973,11 +983,13 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
|
|||||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||||
|
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
|
||||||
|
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
|
||||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||||
|
|||||||
616
infrastructure_files/getting-started-enterprise.sh
Executable file
616
infrastructure_files/getting-started-enterprise.sh
Executable file
@@ -0,0 +1,616 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -o pipefail
|
||||||
|
|
||||||
|
# NetBird Enterprise — Getting Started
|
||||||
|
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
|
||||||
|
# embedded identity provider. Owner is created via first-login flow.
|
||||||
|
|
||||||
|
SED_STRIP_PADDING='s/=//g'
|
||||||
|
|
||||||
|
check_docker_compose() {
|
||||||
|
if command -v docker-compose &> /dev/null; then
|
||||||
|
echo "docker-compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if docker compose --help &> /dev/null; then
|
||||||
|
echo "docker compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_openssl() {
|
||||||
|
if ! command -v openssl &> /dev/null; then
|
||||||
|
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_secret() {
|
||||||
|
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_b64_key() {
|
||||||
|
openssl rand -base64 32
|
||||||
|
}
|
||||||
|
|
||||||
|
check_nb_domain() {
|
||||||
|
local domain="$1"
|
||||||
|
if [[ -z "$domain" ]]; then
|
||||||
|
echo "The domain cannot be empty." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ "$domain" == "netbird.example.com" ]]; then
|
||||||
|
echo "The domain cannot be netbird.example.com" > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
|
||||||
|
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
|
||||||
|
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
check_domain_resolves() {
|
||||||
|
local domain="$1"
|
||||||
|
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
|
||||||
|
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
|
||||||
|
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
|
||||||
|
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
read_nb_domain() {
|
||||||
|
local value=""
|
||||||
|
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if ! check_nb_domain "$value"; then
|
||||||
|
read_nb_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if ! check_domain_resolves "$value"; then
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
|
||||||
|
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
|
||||||
|
local confirm=""
|
||||||
|
echo -n "Continue anyway? [y/N]: " > /dev/stderr
|
||||||
|
read -r confirm < /dev/tty
|
||||||
|
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
|
||||||
|
read_nb_domain
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_required() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_secret() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -rs value < /dev/tty
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
# read_yes_no "<prompt>" [<default y|n>]
|
||||||
|
read_yes_no() {
|
||||||
|
local prompt="$1"
|
||||||
|
local default="${2:-n}"
|
||||||
|
local hint
|
||||||
|
if [[ "$default" == "y" ]]; then
|
||||||
|
hint="[Y/n]"
|
||||||
|
else
|
||||||
|
hint="[y/N]"
|
||||||
|
fi
|
||||||
|
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||||
|
local ans=""
|
||||||
|
read -r ans < /dev/tty
|
||||||
|
if [[ -z "$ans" ]]; then
|
||||||
|
ans="$default"
|
||||||
|
fi
|
||||||
|
case "$ans" in
|
||||||
|
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||||
|
*) echo "no" ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_postgres() {
|
||||||
|
set +e
|
||||||
|
echo -n "Waiting for postgres to become ready"
|
||||||
|
local counter=1
|
||||||
|
while true; do
|
||||||
|
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
if [[ $counter -eq 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Postgres is taking too long. Recent logs:"
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
set -e
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment() {
|
||||||
|
check_openssl
|
||||||
|
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||||
|
|
||||||
|
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
|
||||||
|
echo "Generated files already exist in $(pwd)."
|
||||||
|
echo "If you want to reinitialize the environment, please remove them first:"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
|
||||||
|
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
|
||||||
|
echo "Be aware this will remove all data from the database."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "NetBird Enterprise bootstrap"
|
||||||
|
echo ""
|
||||||
|
echo "Traffic flow:"
|
||||||
|
echo " Enables traffic events logging on the management server."
|
||||||
|
echo " When enabled, the NetBird stack also runs NATS along with two"
|
||||||
|
echo " additional containers: netbird-receiver (the traffic log receiver"
|
||||||
|
echo " service) and netbird-enricher (the traffic log enricher service)."
|
||||||
|
echo " It still has to be turned on from the dashboard settings afterwards."
|
||||||
|
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
|
||||||
|
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
NETBIRD_DOMAIN=$(read_nb_domain)
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
|
||||||
|
|
||||||
|
GHCR_USERNAME="netbirdExtAccess1"
|
||||||
|
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
|
||||||
|
|
||||||
|
POSTGRES_USER="netbird"
|
||||||
|
POSTGRES_DB="netbird"
|
||||||
|
POSTGRES_PASSWORD=$(rand_secret)
|
||||||
|
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
|
||||||
|
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
|
||||||
|
|
||||||
|
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
|
||||||
|
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Selected:"
|
||||||
|
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
|
||||||
|
echo " Domain: ${NETBIRD_DOMAIN}"
|
||||||
|
echo ""
|
||||||
|
echo "Rendering files into $(pwd) ..."
|
||||||
|
install -m 600 /dev/null .env
|
||||||
|
render_env >> .env
|
||||||
|
render_docker_compose > docker-compose.yml
|
||||||
|
|
||||||
|
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
|
||||||
|
fi
|
||||||
|
render_caddyfile > Caddyfile
|
||||||
|
install -m 600 /dev/null config.yaml
|
||||||
|
render_config_yaml >> config.yaml
|
||||||
|
|
||||||
|
echo "Logging in to ghcr.io ..."
|
||||||
|
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||||
|
unset GHCR_TOKEN
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Pulling images ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND pull
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting postgres ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||||
|
sleep 2
|
||||||
|
wait_postgres
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting remaining services ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done."
|
||||||
|
echo ""
|
||||||
|
echo "Dashboard: https://${NETBIRD_DOMAIN}"
|
||||||
|
echo ""
|
||||||
|
echo "Open the dashboard in a browser to complete the first-login owner setup."
|
||||||
|
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env"
|
||||||
|
echo ""
|
||||||
|
echo "Tail logs:"
|
||||||
|
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Renderers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
render_env() {
|
||||||
|
cat <<EOF
|
||||||
|
# Generated by getting-started-enterprise.sh
|
||||||
|
# Holds all configuration and secrets for the stack. Mode 600.
|
||||||
|
|
||||||
|
# Features (set by the script; don't edit without re-running)
|
||||||
|
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||||
|
|
||||||
|
# Domain
|
||||||
|
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
|
||||||
|
|
||||||
|
# Image tags. Default to "latest"
|
||||||
|
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
|
||||||
|
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
|
||||||
|
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
# License keys
|
||||||
|
EOF
|
||||||
|
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
cat <<EOF
|
||||||
|
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
# Postgres
|
||||||
|
POSTGRES_USER=${POSTGRES_USER}
|
||||||
|
POSTGRES_DB=${POSTGRES_DB}
|
||||||
|
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||||
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
|
||||||
|
|
||||||
|
# Relay
|
||||||
|
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
|
||||||
|
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||||
|
|
||||||
|
# Datastore encryption
|
||||||
|
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
|
||||||
|
# Dashboard OIDC scopes
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_docker_compose() {
|
||||||
|
render_compose_header
|
||||||
|
render_compose_common
|
||||||
|
render_compose_server
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
render_compose_flow
|
||||||
|
fi
|
||||||
|
render_compose_postgres
|
||||||
|
render_compose_footer
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_header() {
|
||||||
|
cat <<'EOF'
|
||||||
|
x-default: &default
|
||||||
|
restart: unless-stopped
|
||||||
|
logging:
|
||||||
|
driver: json-file
|
||||||
|
options:
|
||||||
|
max-size: '500m'
|
||||||
|
max-file: '2'
|
||||||
|
|
||||||
|
services:
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_common() {
|
||||||
|
cat <<'EOF'
|
||||||
|
caddy:
|
||||||
|
<<: *default
|
||||||
|
image: caddy:2
|
||||||
|
container_name: netbird-caddy
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
- '443:443/udp'
|
||||||
|
- '80:80'
|
||||||
|
volumes:
|
||||||
|
- netbird_caddy_data:/data
|
||||||
|
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||||
|
|
||||||
|
dashboard:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
|
||||||
|
container_name: netbird-dashboard
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||||
|
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
|
||||||
|
- AUTH_AUDIENCE=netbird-dashboard
|
||||||
|
- AUTH_CLIENT_ID=netbird-dashboard
|
||||||
|
- AUTH_CLIENT_SECRET=
|
||||||
|
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
|
||||||
|
- USE_AUTH0=false
|
||||||
|
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
|
||||||
|
- AUTH_REDIRECT_URI=/nb-auth
|
||||||
|
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||||
|
- NETBIRD_TOKEN_SOURCE=accessToken
|
||||||
|
- NGINX_SSL_PORT=443
|
||||||
|
- LETSENCRYPT_DOMAIN=
|
||||||
|
- LETSENCRYPT_EMAIL=
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_server() {
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird-server:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
|
||||||
|
container_name: netbird-server
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
dashboard:
|
||||||
|
condition: service_started
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
ports:
|
||||||
|
- '3478:3478/udp'
|
||||||
|
volumes:
|
||||||
|
- netbird_data:/var/lib/netbird
|
||||||
|
- ./config.yaml:/etc/netbird/config.yaml
|
||||||
|
command: ["--config", "/etc/netbird/config.yaml"]
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_flow() {
|
||||||
|
cat <<'EOF'
|
||||||
|
nats:
|
||||||
|
<<: *default
|
||||||
|
image: nats:2
|
||||||
|
container_name: netbird-nats
|
||||||
|
networks: [netbird]
|
||||||
|
volumes:
|
||||||
|
- netbird_nats_data:/data
|
||||||
|
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||||
|
|
||||||
|
enricher:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
|
||||||
|
container_name: netbird-enricher
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
volumes:
|
||||||
|
- netbird_enricher:/var/lib/netbird
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
- NB_DATADIR=/var/lib/netbird
|
||||||
|
- NB_MANAGEMENT_STORE_ENGINE=postgres
|
||||||
|
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
|
||||||
|
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
|
||||||
|
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
- NB_FLOW_ADAPTER_TYPE=nats
|
||||||
|
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||||
|
- NB_FLOW_NATS_STREAM=traffic-events
|
||||||
|
- NB_METRICS_PORT=9091
|
||||||
|
- NB_PERSISTENCE_RETENTION_PERIOD=168h
|
||||||
|
|
||||||
|
receiver:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
|
||||||
|
container_name: netbird-receiver
|
||||||
|
networks: [netbird]
|
||||||
|
depends_on:
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
|
||||||
|
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
- NB_FLOW_LISTEN_PORT=80
|
||||||
|
- NB_FLOW_ADAPTER_TYPE=nats
|
||||||
|
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
|
||||||
|
- NB_FLOW_NATS_STREAM=traffic-events
|
||||||
|
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_postgres() {
|
||||||
|
cat <<'EOF'
|
||||||
|
postgres:
|
||||||
|
<<: *default
|
||||||
|
image: postgres:17
|
||||||
|
container_name: netbird-postgres
|
||||||
|
networks: [netbird]
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=${POSTGRES_USER}
|
||||||
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||||
|
- POSTGRES_DB=${POSTGRES_DB}
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 10
|
||||||
|
volumes:
|
||||||
|
- netbird_postgres:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_compose_footer() {
|
||||||
|
cat <<'EOF'
|
||||||
|
volumes:
|
||||||
|
netbird_data:
|
||||||
|
EOF
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird_nats_data:
|
||||||
|
netbird_enricher:
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
cat <<'EOF'
|
||||||
|
netbird_postgres:
|
||||||
|
netbird_caddy_data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_caddyfile() {
|
||||||
|
cat <<'EOF'
|
||||||
|
{
|
||||||
|
servers :80,:443 {
|
||||||
|
protocols h1 h2c h2 h3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(security_headers) {
|
||||||
|
header * {
|
||||||
|
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||||
|
X-Content-Type-Options "nosniff"
|
||||||
|
X-Frame-Options "SAMEORIGIN"
|
||||||
|
X-XSS-Protection "1; mode=block"
|
||||||
|
-Server
|
||||||
|
Referrer-Policy strict-origin-when-cross-origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:80 {
|
||||||
|
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
|
||||||
|
}
|
||||||
|
|
||||||
|
{$CADDY_SECURE_DOMAIN}:443 {
|
||||||
|
import security_headers
|
||||||
|
# Signal (gRPC over h2c)
|
||||||
|
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
|
||||||
|
# Management (gRPC over h2c + HTTP)
|
||||||
|
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
|
||||||
|
reverse_proxy /api/* netbird-server:80
|
||||||
|
reverse_proxy /ws-proxy/* netbird-server:80
|
||||||
|
# Embedded IdP (OAuth2 endpoints served by netbird server)
|
||||||
|
reverse_proxy /oauth2/* netbird-server:80
|
||||||
|
# Relay (WebSocket multiplexed on the same port)
|
||||||
|
reverse_proxy /relay* netbird-server:80
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<'EOF'
|
||||||
|
# Flow receiver (gRPC over h2c)
|
||||||
|
reverse_proxy /flow.FlowService/* h2c://receiver:80
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
cat <<'EOF'
|
||||||
|
# Dashboard
|
||||||
|
reverse_proxy /* dashboard:80
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
render_config_yaml() {
|
||||||
|
cat <<EOF
|
||||||
|
# NetBird Enterprise server configuration.
|
||||||
|
# Generated by getting-started-enterprise.sh. Mode 600.
|
||||||
|
|
||||||
|
server:
|
||||||
|
listenAddress: ":80"
|
||||||
|
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
|
||||||
|
|
||||||
|
metricsPort: 9090
|
||||||
|
healthcheckAddress: ":9000"
|
||||||
|
|
||||||
|
logLevel: "info"
|
||||||
|
logFile: "console"
|
||||||
|
|
||||||
|
# TLS is terminated by Caddy in front; leave this block empty.
|
||||||
|
tls:
|
||||||
|
certFile: ""
|
||||||
|
keyFile: ""
|
||||||
|
letsencrypt:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
|
||||||
|
dataDir: "/var/lib/netbird/"
|
||||||
|
|
||||||
|
disableAnonymousMetrics: false
|
||||||
|
disableGeoliteUpdate: false
|
||||||
|
|
||||||
|
auth:
|
||||||
|
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
|
||||||
|
localAuthDisabled: false
|
||||||
|
signKeyRefreshEnabled: false
|
||||||
|
dashboardRedirectURIs:
|
||||||
|
- "https://${NETBIRD_DOMAIN}/nb-auth"
|
||||||
|
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
|
||||||
|
cliRedirectURIs:
|
||||||
|
- "http://localhost:53000/"
|
||||||
|
|
||||||
|
store:
|
||||||
|
engine: "postgres"
|
||||||
|
dsn: "${POSTGRES_DSN}"
|
||||||
|
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
|
||||||
|
|
||||||
|
activityStore:
|
||||||
|
engine: "postgres"
|
||||||
|
dsn: "${POSTGRES_DSN}"
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
trafficFlow:
|
||||||
|
enabled: true
|
||||||
|
address: "https://${NETBIRD_DOMAIN}:443"
|
||||||
|
interval: "60s"
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
init_environment
|
||||||
@@ -351,6 +351,11 @@ initialize_default_values() {
|
|||||||
NETBIRD_STUN_PORT=3478
|
NETBIRD_STUN_PORT=3478
|
||||||
|
|
||||||
# Docker images
|
# Docker images
|
||||||
|
# Record whether the operator explicitly pinned the server/proxy images via
|
||||||
|
# env vars, so the agent-network preset can pick its own defaults without
|
||||||
|
# clobbering an explicit override.
|
||||||
|
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||||
|
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||||
# Combined server replaces separate signal, relay, and management containers
|
# Combined server replaces separate signal, relay, and management containers
|
||||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||||
@@ -398,7 +403,53 @@ configure_domain() {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apply_agent_network_preset() {
|
||||||
|
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
|
||||||
|
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
|
||||||
|
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
|
||||||
|
# inputs we still need from the operator are the domain (handled by
|
||||||
|
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
|
||||||
|
# and the ACME email — both honor env vars first and fall back to a
|
||||||
|
# prompt only when unset. CrowdSec is intentionally off.
|
||||||
|
REVERSE_PROXY_TYPE="0"
|
||||||
|
ENABLE_PROXY="true"
|
||||||
|
ENABLE_CROWDSEC="false"
|
||||||
|
|
||||||
|
# Agent-network ships dedicated server/proxy images. Honor an explicit
|
||||||
|
# env override; otherwise pin the agent-network builds.
|
||||||
|
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||||
|
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
|
||||||
|
fi
|
||||||
|
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||||
|
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||||
|
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||||
|
else
|
||||||
|
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
|
||||||
|
echo " - reverse proxy: built-in Traefik" > /dev/stderr
|
||||||
|
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
|
||||||
|
echo " - server image: ${NETBIRD_SERVER_IMAGE}" > /dev/stderr
|
||||||
|
echo " - proxy image: ${NETBIRD_PROXY_IMAGE}" > /dev/stderr
|
||||||
|
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
|
||||||
|
echo " - CrowdSec: disabled" > /dev/stderr
|
||||||
|
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
}
|
||||||
|
|
||||||
configure_reverse_proxy() {
|
configure_reverse_proxy() {
|
||||||
|
# Short-circuit: agent-network preset locks every reverse-proxy /
|
||||||
|
# proxy / CrowdSec choice and bypasses the interactive prompts.
|
||||||
|
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||||
|
apply_agent_network_preset
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
# Prompt for reverse proxy type
|
# Prompt for reverse proxy type
|
||||||
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
|
||||||
|
|
||||||
@@ -910,6 +961,15 @@ NGINX_SSL_PORT=443
|
|||||||
# Letsencrypt
|
# Letsencrypt
|
||||||
LETSENCRYPT_DOMAIN=none
|
LETSENCRYPT_DOMAIN=none
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
# Agent-network preset: dashboard hides the standard NetBird surfaces
|
||||||
|
# and exposes only the AI Observability + agent-network configuration
|
||||||
|
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
|
||||||
|
NETBIRD_AGENT_NETWORK_ONLY=true
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -946,6 +1006,17 @@ NB_PROXY_PROXY_PROTOCOL=true
|
|||||||
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
# Agent-network preset: turn the proxy into the private reverse-proxy
|
||||||
|
# ingress for agent-network synth services. Disables the public-facing
|
||||||
|
# surface so the proxy serves only synth-generated routes (the
|
||||||
|
# llm_router-driven LLM endpoints) and the per-account inbound
|
||||||
|
# listeners on the embedded netstack.
|
||||||
|
NB_PROXY_PRIVATE=true
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
|
||||||
@@ -1326,12 +1397,20 @@ print_builtin_traefik_instructions() {
|
|||||||
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
echo "This setup is ideal for homelabs and smaller organization deployments."
|
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||||
echo "For enterprise environments requiring high availability and advanced integrations,"
|
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||||
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
echo "consider a commercial on-prem license:"
|
||||||
echo ""
|
echo ""
|
||||||
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
echo " Commercial license: https://netbird.ai/pricing"
|
||||||
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
echo " Documentation: https://docs.netbird.io/agent-network"
|
||||||
|
else
|
||||||
|
echo "This setup is ideal for homelabs and smaller organization deployments."
|
||||||
|
echo "For enterprise environments requiring high availability and advanced integrations,"
|
||||||
|
echo "consider a commercial on-prem license or scaling your open source deployment:"
|
||||||
|
echo ""
|
||||||
|
echo " Commercial license: https://netbird.io/pricing#on-prem"
|
||||||
|
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
|
||||||
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
if [[ "$ENABLE_PROXY" == "true" ]]; then
|
||||||
echo "NetBird Proxy:"
|
echo "NetBird Proxy:"
|
||||||
@@ -1354,6 +1433,11 @@ print_builtin_traefik_instructions() {
|
|||||||
echo ""
|
echo ""
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
|
||||||
|
echo "Note: The public domain is only for setting up secure connections."
|
||||||
|
echo "Your APIs and agent services remain private and are never exposed publicly."
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
638
infrastructure_files/migrate-to-enterprise.sh
Executable file
@@ -0,0 +1,638 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -o pipefail
|
||||||
|
|
||||||
|
# NetBird — community combined → Enterprise combined migration
|
||||||
|
#
|
||||||
|
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
|
||||||
|
# by docker compose) and config.yaml.enterprise alongside the operator's
|
||||||
|
# existing files. Original docker-compose.yml and config.yaml are never
|
||||||
|
# modified.
|
||||||
|
#
|
||||||
|
# Steps (all optional, asked interactively):
|
||||||
|
# 1. Image swap — replace community images with enterprise cloud images.
|
||||||
|
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
|
||||||
|
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
|
||||||
|
#
|
||||||
|
# To revert:
|
||||||
|
# docker compose down
|
||||||
|
# rm -f docker-compose.override.yml config.yaml.enterprise
|
||||||
|
# # If Postgres migration was done, also restore the SQLite backup printed
|
||||||
|
# # at the end of this script's run.
|
||||||
|
# docker compose up -d
|
||||||
|
|
||||||
|
OVERRIDE_FILE="docker-compose.override.yml"
|
||||||
|
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||||
|
|
||||||
|
check_docker_compose() {
|
||||||
|
if command -v docker-compose &> /dev/null; then
|
||||||
|
echo "docker-compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if docker compose --help &> /dev/null; then
|
||||||
|
echo "docker compose"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "docker-compose is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
check_yq() {
|
||||||
|
if ! command -v yq &> /dev/null; then
|
||||||
|
cat > /dev/stderr <<'EOF'
|
||||||
|
yq is required to parse and update YAML safely.
|
||||||
|
|
||||||
|
macOS: brew install yq
|
||||||
|
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
|
||||||
|
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
|
||||||
|
|
||||||
|
EOF
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if ! yq --version 2>&1 | grep -q "mikefarah"; then
|
||||||
|
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
check_openssl() {
|
||||||
|
if ! command -v openssl &> /dev/null; then
|
||||||
|
echo "openssl is not installed or not in PATH." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
rand_password() {
|
||||||
|
openssl rand -hex 32
|
||||||
|
}
|
||||||
|
|
||||||
|
read_required() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -r value < /dev/tty
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_secret() {
|
||||||
|
local prompt="$1"
|
||||||
|
local value=""
|
||||||
|
while [[ -z "$value" ]]; do
|
||||||
|
echo -n "$prompt: " > /dev/stderr
|
||||||
|
read -rs value < /dev/tty
|
||||||
|
echo "" > /dev/stderr
|
||||||
|
if [[ -z "$value" ]]; then
|
||||||
|
echo "Value cannot be empty." > /dev/stderr
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "$value"
|
||||||
|
}
|
||||||
|
|
||||||
|
read_yes_no() {
|
||||||
|
local prompt="$1"
|
||||||
|
local default="${2:-n}"
|
||||||
|
local hint
|
||||||
|
if [[ "$default" == "y" ]]; then
|
||||||
|
hint="[Y/n]"
|
||||||
|
else
|
||||||
|
hint="[y/N]"
|
||||||
|
fi
|
||||||
|
echo -n "${prompt} ${hint}: " > /dev/stderr
|
||||||
|
local ans=""
|
||||||
|
read -r ans < /dev/tty
|
||||||
|
if [[ -z "$ans" ]]; then
|
||||||
|
ans="$default"
|
||||||
|
fi
|
||||||
|
case "$ans" in
|
||||||
|
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
|
||||||
|
*) echo "no" ;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Detection — read the operator's existing compose to find service names and
|
||||||
|
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
detect_combined_service() {
|
||||||
|
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_dashboard_service() {
|
||||||
|
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_config_yaml_host_path() {
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_data_volume() {
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_exposed_address() {
|
||||||
|
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
|
||||||
|
}
|
||||||
|
|
||||||
|
detect_compose_network() {
|
||||||
|
local tag
|
||||||
|
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
|
||||||
|
case "$tag" in
|
||||||
|
"!!seq")
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
|
||||||
|
;;
|
||||||
|
"!!map")
|
||||||
|
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "default"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Renderers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Build docker-compose.override.yml from the steps the operator selected.
|
||||||
|
# Service names match what we detected on the operator's side.
|
||||||
|
render_override() {
|
||||||
|
cat <<EOF
|
||||||
|
# Generated by migrate-to-enterprise.sh. Mode 644.
|
||||||
|
# Merged with docker-compose.yml automatically by Docker Compose.
|
||||||
|
# Remove this file (and config.yaml.enterprise if present) to revert.
|
||||||
|
|
||||||
|
services:
|
||||||
|
${DASHBOARD_SERVICE}:
|
||||||
|
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
|
||||||
|
|
||||||
|
${COMBINED_SERVICE}:
|
||||||
|
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
volumes:
|
||||||
|
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
|
||||||
|
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
|
||||||
|
|
||||||
|
postgres:
|
||||||
|
image: postgres:17
|
||||||
|
container_name: netbird-postgres
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: netbird
|
||||||
|
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
|
||||||
|
POSTGRES_DB: netbird
|
||||||
|
volumes:
|
||||||
|
- netbird_postgres:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 20
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
nats:
|
||||||
|
image: nats:2
|
||||||
|
container_name: netbird-nats
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
|
||||||
|
volumes:
|
||||||
|
- netbird_nats_data:/data
|
||||||
|
|
||||||
|
flow-enricher:
|
||||||
|
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
|
||||||
|
container_name: netbird-flow-enricher
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
NB_DATADIR: /var/lib/netbird
|
||||||
|
NB_MANAGEMENT_STORE_ENGINE: postgres
|
||||||
|
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
|
||||||
|
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
|
||||||
|
NB_FLOW_ADAPTER_TYPE: nats
|
||||||
|
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||||
|
NB_FLOW_NATS_STREAM: traffic-events
|
||||||
|
NB_METRICS_PORT: 9091
|
||||||
|
NB_PERSISTENCE_RETENTION_PERIOD: 168h
|
||||||
|
|
||||||
|
flow-receiver:
|
||||||
|
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
|
||||||
|
container_name: netbird-flow-receiver
|
||||||
|
restart: unless-stopped
|
||||||
|
networks: [${COMPOSE_NETWORK}]
|
||||||
|
depends_on:
|
||||||
|
nats:
|
||||||
|
condition: service_started
|
||||||
|
environment:
|
||||||
|
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
|
||||||
|
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
|
||||||
|
NB_FLOW_LISTEN_PORT: 80
|
||||||
|
NB_FLOW_ADAPTER_TYPE: nats
|
||||||
|
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
|
||||||
|
NB_FLOW_NATS_STREAM: traffic-events
|
||||||
|
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
|
||||||
|
- traefik.http.routers.netbird-flow.entrypoints=websecure
|
||||||
|
- traefik.http.routers.netbird-flow.tls=true
|
||||||
|
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
|
||||||
|
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
|
||||||
|
- traefik.http.routers.netbird-flow.priority=100
|
||||||
|
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
|
||||||
|
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
|
||||||
|
EOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Volume declarations for anything new the override introduced
|
||||||
|
local has_volumes="no"
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
has_volumes="yes"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$has_volumes" == "yes" ]]; then
|
||||||
|
cat <<EOF
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
EOF
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo " netbird_postgres:"
|
||||||
|
fi
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
echo " netbird_nats_data:"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build config.yaml.enterprise by yq-editing the operator's existing
|
||||||
|
# config.yaml. We don't touch the original file.
|
||||||
|
render_enterprise_config() {
|
||||||
|
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
|
||||||
|
|
||||||
|
yq eval "
|
||||||
|
.server.store.engine = \"postgres\" |
|
||||||
|
.server.store.dsn = \"$pg_dsn\" |
|
||||||
|
.server.activityStore.engine = \"postgres\" |
|
||||||
|
.server.activityStore.dsn = \"$pg_dsn\" |
|
||||||
|
.server.authStore.engine = \"postgres\" |
|
||||||
|
.server.authStore.dsn = \"$pg_dsn\"
|
||||||
|
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
local flow_addr="${NETBIRD_DOMAIN}"
|
||||||
|
yq eval -i "
|
||||||
|
.server.trafficFlow.enabled = true |
|
||||||
|
.server.trafficFlow.address = \"$flow_addr\" |
|
||||||
|
.server.trafficFlow.interval = \"60s\"
|
||||||
|
" "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Execution steps
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
resolve_data_volume() {
|
||||||
|
local short="$1"
|
||||||
|
local actual
|
||||||
|
# Resolve project-prefixed volume name from Docker Compose config first.
|
||||||
|
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
|
||||||
|
if [[ -n "$actual" && "$actual" != "null" ]]; then
|
||||||
|
echo "$actual"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
# Relative bind mount: docker-compose resolves it against the compose
|
||||||
|
# file's directory, but `docker run -v` resolves it against the current
|
||||||
|
# working directory. Normalize to an absolute path so both interpretations
|
||||||
|
# agree (and the printed revert command works from any CWD).
|
||||||
|
if [[ "$short" == ./* || "$short" == ../* ]]; then
|
||||||
|
local compose_dir
|
||||||
|
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
|
||||||
|
(
|
||||||
|
cd "$compose_dir"
|
||||||
|
cd "$(dirname "$short")"
|
||||||
|
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
|
||||||
|
echo "$short"
|
||||||
|
}
|
||||||
|
|
||||||
|
backup_sqlite() {
|
||||||
|
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
|
||||||
|
mkdir -p "$BACKUP_DIR"
|
||||||
|
local data_volume_actual
|
||||||
|
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||||
|
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
|
||||||
|
docker run --rm \
|
||||||
|
-v "${data_volume_actual}:/var/lib/netbird:ro" \
|
||||||
|
-v "${BACKUP_DIR}:/backup" \
|
||||||
|
busybox \
|
||||||
|
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
|
||||||
|
local copied
|
||||||
|
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
|
||||||
|
if [[ -z "$copied" ]]; then
|
||||||
|
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo " done"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_migrate_store() {
|
||||||
|
echo "Running migrate-store (SQLite → Postgres) ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
|
||||||
|
echo " done"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
init_migration() {
|
||||||
|
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||||
|
check_yq
|
||||||
|
check_openssl
|
||||||
|
|
||||||
|
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||||
|
|
||||||
|
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||||
|
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
|
||||||
|
echo "Migration artifacts already exist in $(pwd):"
|
||||||
|
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
|
||||||
|
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo ""
|
||||||
|
echo "Either you've already migrated, or a previous run was interrupted."
|
||||||
|
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
COMBINED_SERVICE=$(detect_combined_service)
|
||||||
|
DASHBOARD_SERVICE=$(detect_dashboard_service)
|
||||||
|
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
|
||||||
|
DATA_VOLUME=$(detect_data_volume)
|
||||||
|
COMPOSE_NETWORK=$(detect_compose_network)
|
||||||
|
|
||||||
|
if [[ -z "$COMBINED_SERVICE" ]]; then
|
||||||
|
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
|
||||||
|
echo "This script targets the community combined-server deployment." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$DASHBOARD_SERVICE" ]]; then
|
||||||
|
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$CONFIG_YAML_HOST" ]]; then
|
||||||
|
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
|
||||||
|
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$DATA_VOLUME" ]]; then
|
||||||
|
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Detected existing deployment:"
|
||||||
|
echo " Combined service: $COMBINED_SERVICE"
|
||||||
|
echo " Dashboard: $DASHBOARD_SERVICE"
|
||||||
|
echo " config.yaml: $CONFIG_YAML_HOST"
|
||||||
|
echo " Data volume: $DATA_VOLUME"
|
||||||
|
echo " Network: $COMPOSE_NETWORK"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
local proceed
|
||||||
|
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||||
|
if [[ "$proceed" != "yes" ]]; then
|
||||||
|
echo "Aborted."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 1 — always (this is the point of the script)
|
||||||
|
MIGRATE_IMAGES="yes"
|
||||||
|
echo ""
|
||||||
|
echo "Step 1: Image swap (community → Enterprise). License key required."
|
||||||
|
NB_LICENSE_KEY=$(read_secret " License key")
|
||||||
|
GHCR_USERNAME="netbirdExtAccess1"
|
||||||
|
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
|
||||||
|
|
||||||
|
# Step 2 — optional
|
||||||
|
echo ""
|
||||||
|
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
|
||||||
|
echo " will be backed up automatically. To fully revert later, restore"
|
||||||
|
echo " that backup and delete docker-compose.override.yml +"
|
||||||
|
echo " config.yaml.enterprise."
|
||||||
|
local confirm
|
||||||
|
confirm=$(read_yes_no " Continue?" "y")
|
||||||
|
if [[ "$confirm" != "yes" ]]; then
|
||||||
|
MIGRATE_POSTGRES="no"
|
||||||
|
echo " Skipping Postgres migration."
|
||||||
|
else
|
||||||
|
POSTGRES_PASSWORD=$(rand_password)
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
|
||||||
|
echo ""
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
# Auth secret MUST match server.authSecret from config.yaml
|
||||||
|
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
|
||||||
|
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
|
||||||
|
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
NETBIRD_DOMAIN=$(detect_exposed_address)
|
||||||
|
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
|
||||||
|
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
|
||||||
|
fi
|
||||||
|
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
|
||||||
|
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
|
||||||
|
|
||||||
|
# We need the encryption key from the existing config.yaml for the enricher
|
||||||
|
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
|
||||||
|
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
|
||||||
|
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
ENABLE_FLOW="no"
|
||||||
|
echo "Step 3 (traffic flow) skipped — requires Postgres."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
apply_changes() {
|
||||||
|
echo ""
|
||||||
|
echo "Writing $OVERRIDE_FILE ..."
|
||||||
|
install -m 644 /dev/null "$OVERRIDE_FILE"
|
||||||
|
render_override > "$OVERRIDE_FILE"
|
||||||
|
|
||||||
|
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
|
||||||
|
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
|
||||||
|
render_enterprise_config
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Persist secrets that the override file references via env interpolation.
|
||||||
|
# We write them to a .env file in the current directory; docker compose
|
||||||
|
# picks it up automatically.
|
||||||
|
echo "Writing .env additions (mode 600) ..."
|
||||||
|
local ENV_FILE=".env"
|
||||||
|
touch "$ENV_FILE"
|
||||||
|
chmod 600 "$ENV_FILE"
|
||||||
|
{
|
||||||
|
echo ""
|
||||||
|
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||||
|
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||||
|
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||||
|
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||||
|
fi
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
|
||||||
|
fi
|
||||||
|
if [[ "$ENABLE_FLOW" == "yes" ]]; then
|
||||||
|
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
|
||||||
|
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
|
||||||
|
fi
|
||||||
|
} >> "$ENV_FILE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Logging in to ghcr.io ..."
|
||||||
|
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
|
||||||
|
unset GHCR_TOKEN
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Pulling enterprise images ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND pull
|
||||||
|
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Stopping existing services (volumes preserved) ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND down
|
||||||
|
|
||||||
|
backup_sqlite
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Starting Postgres ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d postgres
|
||||||
|
|
||||||
|
# Wait for healthy
|
||||||
|
local counter=0
|
||||||
|
echo -n "Waiting for Postgres to become ready"
|
||||||
|
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
|
||||||
|
echo -n " ."
|
||||||
|
sleep 2
|
||||||
|
counter=$((counter + 1))
|
||||||
|
if [[ $counter -ge 60 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Postgres did not become ready in 120s. Recent logs:"
|
||||||
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo " done"
|
||||||
|
|
||||||
|
run_migrate_store
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Bringing up all services ..."
|
||||||
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Migration complete."
|
||||||
|
}
|
||||||
|
|
||||||
|
print_summary() {
|
||||||
|
echo ""
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " Summary"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " Images: swapped to enterprise"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
|
||||||
|
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
|
||||||
|
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
|
||||||
|
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
|
||||||
|
echo ""
|
||||||
|
echo " Generated files (next to your docker-compose.yml):"
|
||||||
|
echo " $OVERRIDE_FILE"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo " .env (license key + secrets, mode 600)"
|
||||||
|
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
|
||||||
|
echo ""
|
||||||
|
echo " Tail logs:"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
|
||||||
|
echo ""
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " To revert"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND down"
|
||||||
|
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
|
||||||
|
# Resolve project-prefixed volume names now (before override is removed).
|
||||||
|
local pg_volume data_volume_actual
|
||||||
|
pg_volume=$(resolve_data_volume "netbird_postgres")
|
||||||
|
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
|
||||||
|
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
|
||||||
|
echo " docker volume rm $pg_volume"
|
||||||
|
echo " # Restore SQLite from the backup created during this run:"
|
||||||
|
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
|
||||||
|
fi
|
||||||
|
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
|
||||||
|
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
|
||||||
|
echo " $DOCKER_COMPOSE_COMMAND up -d"
|
||||||
|
echo "──────────────────────────────────────────────────────────────────────"
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
init_migration
|
||||||
|
apply_changes
|
||||||
|
print_summary
|
||||||
@@ -497,7 +497,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
|||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||||
|
|
||||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||||
peerIDs: make(map[string]struct{}),
|
peerIDs: make(map[string]struct{}),
|
||||||
@@ -610,12 +610,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return nil, nil, 0, err
|
return nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
startPosture := time.Now()
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, 0, err
|
return nil, nil, 0, err
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
|
||||||
|
|
||||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetServiceManager(s.ServiceManager())
|
proxyService.SetServiceManager(s.ServiceManager())
|
||||||
proxyService.SetProxyController(s.ServiceProxyController())
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ const (
|
|||||||
reconnThreshold = 5 * time.Minute
|
reconnThreshold = 5 * time.Minute
|
||||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||||
)
|
)
|
||||||
|
|
||||||
type lfConfig struct {
|
type lfConfig struct {
|
||||||
@@ -139,7 +139,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
|||||||
state.lastSeen = now
|
state.lastSeen = now
|
||||||
}
|
}
|
||||||
|
|
||||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||||
h := fnv.New64a()
|
h := fnv.New64a()
|
||||||
|
|
||||||
h.Write([]byte(meta.WtVersion))
|
h.Write([]byte(meta.WtVersion))
|
||||||
@@ -147,14 +147,6 @@ func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
|||||||
h.Write([]byte(meta.KernelVersion))
|
h.Write([]byte(meta.KernelVersion))
|
||||||
h.Write([]byte(meta.Hostname))
|
h.Write([]byte(meta.Hostname))
|
||||||
h.Write([]byte(meta.SystemSerialNumber))
|
h.Write([]byte(meta.SystemSerialNumber))
|
||||||
h.Write([]byte(pubip))
|
|
||||||
|
|
||||||
macs := uint64(0)
|
return h.Sum64()
|
||||||
for _, na := range meta.NetworkAddresses {
|
|
||||||
for _, r := range na.Mac {
|
|
||||||
macs += uint64(r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.Sum64() + macs
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,9 +164,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
|||||||
KernelVersion: "5.15.0-76-generic",
|
KernelVersion: "5.15.0-76-generic",
|
||||||
Hostname: "prod-server-database-01",
|
Hostname: "prod-server-database-01",
|
||||||
SystemSerialNumber: "PC-1234567890",
|
SystemSerialNumber: "PC-1234567890",
|
||||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
|
||||||
}
|
}
|
||||||
pubip := "8.8.8.8"
|
|
||||||
|
|
||||||
var resultString string
|
var resultString string
|
||||||
var resultUint uint64
|
var resultUint uint64
|
||||||
@@ -175,7 +173,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
resultString = builderString(meta, pubip)
|
resultString = builderString(meta)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -183,7 +181,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
resultString = fnvHashToString(meta, pubip)
|
resultString = fnvHashToString(meta)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -191,7 +189,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
resultUint = metaHash(meta, pubip)
|
resultUint = metaHash(meta)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -199,29 +197,20 @@ func BenchmarkHashingMethods(b *testing.B) {
|
|||||||
_ = resultUint
|
_ = resultUint
|
||||||
}
|
}
|
||||||
|
|
||||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||||
h := fnv.New64a()
|
h := fnv.New64a()
|
||||||
|
|
||||||
if len(meta.NetworkAddresses) != 0 {
|
|
||||||
for _, na := range meta.NetworkAddresses {
|
|
||||||
h.Write([]byte(na.Mac))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Write([]byte(meta.WtVersion))
|
h.Write([]byte(meta.WtVersion))
|
||||||
h.Write([]byte(meta.OSVersion))
|
h.Write([]byte(meta.OSVersion))
|
||||||
h.Write([]byte(meta.KernelVersion))
|
h.Write([]byte(meta.KernelVersion))
|
||||||
h.Write([]byte(meta.Hostname))
|
h.Write([]byte(meta.Hostname))
|
||||||
h.Write([]byte(meta.SystemSerialNumber))
|
h.Write([]byte(meta.SystemSerialNumber))
|
||||||
h.Write([]byte(pubip))
|
|
||||||
|
|
||||||
return strconv.FormatUint(h.Sum64(), 16)
|
return strconv.FormatUint(h.Sum64(), 16)
|
||||||
}
|
}
|
||||||
|
|
||||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||||
mac := getMacAddress(meta.NetworkAddresses)
|
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
|
||||||
len(pubip) + len(mac) + 6
|
|
||||||
|
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.Grow(estimatedSize)
|
b.Grow(estimatedSize)
|
||||||
@@ -235,23 +224,10 @@ func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
|||||||
b.WriteString(meta.Hostname)
|
b.WriteString(meta.Hostname)
|
||||||
b.WriteByte('|')
|
b.WriteByte('|')
|
||||||
b.WriteString(meta.SystemSerialNumber)
|
b.WriteString(meta.SystemSerialNumber)
|
||||||
b.WriteByte('|')
|
|
||||||
b.WriteString(pubip)
|
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
|
||||||
if len(nas) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
macs := make([]string, 0, len(nas))
|
|
||||||
for _, na := range nas {
|
|
||||||
macs = append(macs, na.Mac)
|
|
||||||
}
|
|
||||||
return strings.Join(macs, "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||||
numKeys := 100000
|
numKeys := 100000
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||||
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
|
|||||||
// Manager for users
|
// Manager for users
|
||||||
usersManager users.Manager
|
usersManager users.Manager
|
||||||
|
|
||||||
|
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||||
|
idpManager idp.Manager
|
||||||
|
|
||||||
// Store for one-time authentication tokens
|
// Store for one-time authentication tokens
|
||||||
tokenStore *OneTimeTokenStore
|
tokenStore *OneTimeTokenStore
|
||||||
|
|
||||||
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyServiceServer creates a new proxy service server.
|
// NewProxyServiceServer creates a new proxy service server.
|
||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
|||||||
pkceVerifierStore: pkceStore,
|
pkceVerifierStore: pkceStore,
|
||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
|
idpManager: idpManager,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
tokenChecker: tokenChecker,
|
tokenChecker: tokenChecker,
|
||||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||||
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||||
|
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||||
// Resolve the principal: when the peer is linked to a user, the human
|
|
||||||
// is the principal so multiple peers owned by the same user share a
|
|
||||||
// single identity. Unlinked peers (machine agents) are their own
|
|
||||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
|
||||||
// tag spend with — user.Email when linked, peer.Name when not.
|
|
||||||
principalID := peer.ID
|
|
||||||
displayIdentity := peer.Name
|
|
||||||
if peer.UserID != "" {
|
|
||||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
|
||||||
principalID = user.Id
|
|
||||||
if user.Email != "" {
|
|
||||||
displayIdentity = user.Email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||||
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||||
|
// user or peer ID, and peer name or user email.
|
||||||
|
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||||
|
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||||
|
// principal so multiple peers owned by the same user share a single
|
||||||
|
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||||
|
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||||
|
// user.Email when linked, peer.Name when not.
|
||||||
|
|
||||||
|
// If the peer isn't associated with a user, return the peer info directly.
|
||||||
|
if peer.UserID == "" {
|
||||||
|
return peer.ID, peer.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||||
|
// if an IdP is available, we gather details on the user from it.
|
||||||
|
principalID := peer.UserID
|
||||||
|
displayIdentity := peer.Name
|
||||||
|
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||||
|
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||||
|
principalID = user.Id
|
||||||
|
if user.Email != "" {
|
||||||
|
displayIdentity = user.Email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// IdP enrichment wins when available — the stored email column is a
|
||||||
|
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||||
|
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||||
|
if s.idpManager != nil {
|
||||||
|
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||||
|
displayIdentity = ud.Email
|
||||||
|
} else if uerr != nil {
|
||||||
|
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return principalID, displayIdentity
|
||||||
|
}
|
||||||
|
|
||||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||||
// groups. Private services authorise against AccessGroups (empty list fails
|
// groups. Private services authorise against AccessGroups (empty list fails
|
||||||
// closed — Validate() rejects that at save time but the RPC is the security
|
// closed — Validate() rejects that at save time but the RPC is the security
|
||||||
|
|||||||
@@ -3,14 +3,19 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockReverseProxyManager struct {
|
type mockReverseProxyManager struct {
|
||||||
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
|||||||
return user, nil, nil
|
return user, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||||
|
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||||
|
// panics if any unexpected method is invoked).
|
||||||
|
type mockTunnelPeersManager struct {
|
||||||
|
peers.Manager
|
||||||
|
peer *peer.Peer
|
||||||
|
peerErr error
|
||||||
|
groups []*types.Group
|
||||||
|
groupsErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||||
|
return m.peer, m.peerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||||
|
return m.peer, m.groups, m.groupsErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||||
|
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||||
|
// an IdP that knows nothing about the user.
|
||||||
|
type mockTunnelIdpManager struct {
|
||||||
|
idp.Manager
|
||||||
|
email string
|
||||||
|
hasData bool
|
||||||
|
err error
|
||||||
|
gotCalls int
|
||||||
|
gotMeta []idp.AppMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||||
|
m.gotCalls++
|
||||||
|
m.gotMeta = append(m.gotMeta, meta)
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
if !m.hasData {
|
||||||
|
// This might not be a thing any of the actual IDP implementations do,
|
||||||
|
// i.e. return a nil value with no error, but it seems valuable to test
|
||||||
|
// that behavior here.
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateUserGroupAccess(t *testing.T) {
|
func TestValidateUserGroupAccess(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||||
|
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||||
|
// (IdP email -> stored User.Email -> peer.Name).
|
||||||
|
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||||
|
const (
|
||||||
|
domain = "app.example.com"
|
||||||
|
accountID = "account1"
|
||||||
|
peerID = "peer1"
|
||||||
|
peerName = "peer-display-name"
|
||||||
|
userID = "user1"
|
||||||
|
)
|
||||||
|
|
||||||
|
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||||
|
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerUserID string
|
||||||
|
storedUsers map[string]*types.User
|
||||||
|
storedErr error
|
||||||
|
noIdP bool
|
||||||
|
idpEmail string
|
||||||
|
idpHasData bool
|
||||||
|
idpErr error
|
||||||
|
expectEmail string
|
||||||
|
expectUserID string
|
||||||
|
expectIdPHit bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "idp email wins over stored email",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp returns empty email",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp has no data",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpHasData: false,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when idp errors",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpErr: errors.New("idp unreachable"),
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stored email when no idp manager",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUser,
|
||||||
|
noIdP: true,
|
||||||
|
expectEmail: "stored@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "idp email when stored email is empty",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUserNoEmail,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: map[string]*types.User{},
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: "idp@example.com",
|
||||||
|
expectUserID: userID,
|
||||||
|
expectIdPHit: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unlinked peer uses peer name and never consults idp",
|
||||||
|
peerUserID: "",
|
||||||
|
storedUsers: storedUser,
|
||||||
|
idpEmail: "idp@example.com",
|
||||||
|
idpHasData: true,
|
||||||
|
expectEmail: peerName,
|
||||||
|
expectUserID: peerID,
|
||||||
|
expectIdPHit: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||||
|
peerUserID: userID,
|
||||||
|
storedUsers: storedUserNoEmail,
|
||||||
|
noIdP: true,
|
||||||
|
expectEmail: peerName,
|
||||||
|
expectUserID: userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||||
|
server := &ProxyServiceServer{
|
||||||
|
serviceManager: &mockReverseProxyManager{
|
||||||
|
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||||
|
},
|
||||||
|
peersManager: &mockTunnelPeersManager{
|
||||||
|
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||||
|
},
|
||||||
|
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||||
|
}
|
||||||
|
|
||||||
|
var idpMock *mockTunnelIdpManager
|
||||||
|
if !tt.noIdP {
|
||||||
|
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||||
|
server.idpManager = idpMock
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||||
|
Domain: domain,
|
||||||
|
TunnelIp: "100.64.0.1",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.True(t, resp.GetValid(), "expected access granted")
|
||||||
|
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||||
|
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||||
|
|
||||||
|
if idpMock != nil {
|
||||||
|
if tt.expectIdPHit {
|
||||||
|
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||||
|
require.Len(t, idpMock.gotMeta, 1)
|
||||||
|
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user