mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-29 11:19:56 +00:00
Compare commits
129 Commits
refactor/m
...
embedded-v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd7bf982c3 | ||
|
|
d568084d61 | ||
|
|
2d7b309004 | ||
|
|
5968cff242 | ||
|
|
cf43841b86 | ||
|
|
739e36a313 | ||
|
|
2bb5421631 | ||
|
|
998ade6e6d | ||
|
|
62f5467cd8 | ||
|
|
1b29995ece | ||
|
|
fd96b8c12f | ||
|
|
6dd6c3f398 | ||
|
|
d1422dcf09 | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c | ||
|
|
17b2044596 | ||
|
|
07101c59ac | ||
|
|
51b6f6291b | ||
|
|
2ebf26006a | ||
|
|
211a26019a | ||
|
|
6c26178ad5 | ||
|
|
af3b7e4497 | ||
|
|
e84f6527f7 | ||
|
|
ac9529ea8c | ||
|
|
f736ef9647 | ||
|
|
cf58bf1ba9 | ||
|
|
522b8ed969 | ||
|
|
c9e99659ea | ||
|
|
c1eecaac26 | ||
|
|
f2c79201b3 | ||
|
|
2fdc3aea4c | ||
|
|
144dfbc12c | ||
|
|
6c9465df54 | ||
|
|
6cd5d6084f | ||
|
|
3bcacffd2c | ||
|
|
65f302b698 | ||
|
|
2f67841b1e | ||
|
|
bf2fb2fd44 | ||
|
|
4e3e3ce6d3 | ||
|
|
5e2830be8a | ||
|
|
f557e665a5 | ||
|
|
fa57eedaf5 | ||
|
|
7cb6388349 | ||
|
|
1f912be673 | ||
|
|
8d329da591 | ||
|
|
8e72967bbe | ||
|
|
c29ef638f4 | ||
|
|
97b7b010f5 | ||
|
|
030c57150f | ||
|
|
0f03c612d1 | ||
|
|
1cc5967198 | ||
|
|
412193c602 | ||
|
|
5e67febf57 | ||
|
|
ee348ba007 | ||
|
|
3d3055dc7f | ||
|
|
2f4ddf0796 | ||
|
|
98d533c8e8 | ||
|
|
ef4ea2e311 | ||
|
|
b41d11bbbe | ||
|
|
f37e228cc2 | ||
|
|
640a267556 | ||
|
|
17359cdc1e | ||
|
|
7e5846a1ee | ||
|
|
517bea0daf | ||
|
|
896530fd82 | ||
|
|
354fd004c7 | ||
|
|
c28e41e82b | ||
|
|
02b9fe704b | ||
|
|
5e200fa571 | ||
|
|
7d61975f6c | ||
|
|
62b36112ea | ||
|
|
df9a6fb020 | ||
|
|
b1b04f9ec6 | ||
|
|
fe15688f20 | ||
|
|
2285db2b62 | ||
|
|
b3f0f53a23 | ||
|
|
5eec9962ba | ||
|
|
393c102f45 | ||
|
|
b41fbad5e1 | ||
|
|
24a5f2252c | ||
|
|
9d189bb3e8 | ||
|
|
8e2505b59c | ||
|
|
97bc1eebde | ||
|
|
32a5a061b8 | ||
|
|
d927ef468a | ||
|
|
d3f3e08035 | ||
|
|
6bb66e0fad | ||
|
|
bc407527f4 | ||
|
|
5543404188 | ||
|
|
c2fdf62f1f | ||
|
|
b9f5264e36 | ||
|
|
97d0a6776f | ||
|
|
7e7e056f3a | ||
|
|
785f94d13f | ||
|
|
bfb6750b13 | ||
|
|
f5e1057127 | ||
|
|
ee393d0e62 | ||
|
|
0b8fc5da59 | ||
|
|
2d0a54f31a | ||
|
|
61ec8d67de | ||
|
|
76add0b9b2 | ||
|
|
a11341f57a | ||
|
|
b135d462d6 | ||
|
|
da37a28951 | ||
|
|
4f884d9f30 | ||
|
|
2bed8b641b | ||
|
|
b4f696272a | ||
|
|
6d937af7a0 | ||
|
|
db5b6cfbb7 | ||
|
|
e75948753a | ||
|
|
047cc958b5 | ||
|
|
cd005ef9a9 | ||
|
|
44ed0c1992 | ||
|
|
d6d3fa95c7 | ||
|
|
fa90283781 | ||
|
|
8bf13b0d0c | ||
|
|
a8541a1529 | ||
|
|
94068d3ebc | ||
|
|
738c585ee7 | ||
|
|
9b5541d17d | ||
|
|
7123e6d1f4 | ||
|
|
62cf9e873b | ||
|
|
9f0aa1ce26 |
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -59,12 +59,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/git-town.yml
vendored
2
.github/workflows/git-town.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
|
||||
10
.github/workflows/golang-test-darwin.yml
vendored
10
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,18 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -45,10 +45,10 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
24
.github/workflows/golang-test-freebsd.yml
vendored
24
.github/workflows/golang-test-freebsd.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
id: test
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -48,14 +48,14 @@ jobs:
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./dns/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./encryption/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./formatter/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./route/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./util/...
|
||||
time go test -tags privileged -timeout 1m -failfast ./version/...
|
||||
|
||||
82
.github/workflows/golang-test-linux.yml
vendored
82
.github/workflows/golang-test-linux.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -119,12 +119,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,11 +158,11 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -175,12 +175,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -246,12 +246,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -266,7 +266,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -306,12 +306,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -325,7 +325,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -363,12 +363,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -383,7 +383,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -407,7 +407,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -424,12 +424,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -440,7 +440,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -484,7 +484,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
@@ -529,12 +529,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -545,7 +545,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -579,10 +579,11 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -623,12 +624,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -639,7 +640,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -673,12 +674,13 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
env:
|
||||
GIT_BRANCH: ${{ github.ref_name }}
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
@@ -692,12 +694,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -708,7 +710,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -734,7 +736,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
|
||||
8
.github/workflows/golang-test-windows.yml
vendored
8
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
run: |
|
||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||
|
||||
- name: test
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
12
.github/workflows/mobile-build-validation.yml
vendored
12
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,11 +16,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
@@ -28,13 +28,13 @@ jobs:
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -54,11 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
38
.github/workflows/release.yml
vendored
38
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -166,12 +166,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -186,9 +186,9 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
@@ -221,7 +221,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
@@ -374,12 +374,12 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -464,17 +464,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -522,7 +522,7 @@ jobs:
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -534,13 +534,13 @@ jobs:
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
|
||||
14
.github/workflows/test-infrastructure-files.yml
vendored
14
.github/workflows/test-infrastructure-files.yml
vendored
@@ -68,17 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
@@ -225,7 +225,7 @@ jobs:
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest .
|
||||
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,11 +19,11 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
@@ -44,11 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
|
||||
@@ -462,9 +462,13 @@ checksum:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./infrastructure_files/getting-started.sh
|
||||
- glob: ./infrastructure_files/getting-started-enterprise.sh
|
||||
- glob: ./infrastructure_files/migrate-to-enterprise.sh
|
||||
|
||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
@@ -25,3 +25,15 @@ setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
|
||||
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||
test-unit:
|
||||
@go test -tags devcert -timeout 10m ./...
|
||||
|
||||
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||
# Narrow the run with env vars, e.g.:
|
||||
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||
test-privileged:
|
||||
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||
|
||||
@@ -37,6 +37,11 @@
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
> ### 🤖 NetBird Agent Network (Beta)
|
||||
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
39
agent-network/README.md
Normal file
39
agent-network/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# NetBird Agent Network
|
||||
|
||||
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||
policy, with no API keys to leak.
|
||||
|
||||
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||
> infrastructure.
|
||||
|
||||
## How it works
|
||||
|
||||
Agent Network is built on two existing NetBird capabilities:
|
||||
|
||||
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||
key server-side, forwards to the API or gateway, and records usage.
|
||||
|
||||
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||
|
||||
## Where the code lives
|
||||
|
||||
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||
components:
|
||||
|
||||
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||
and runs the per-request middleware pipeline.
|
||||
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||
and usage/access logs.
|
||||
|
||||
## Documentation
|
||||
|
||||
Full documentation, architecture, and quickstart:
|
||||
**https://docs.netbird.io/agent-network**
|
||||
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: string(activeProf.ID),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
196
client/cmd/service_privileged_test.go
Normal file
196
client/cmd/service_privileged_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
//go:build privileged
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1,16 +1,12 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
106
client/cmd/up.go
106
client/cmd/up.go
@@ -401,6 +401,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
req.DisableVNCApproval = &disableVNCApproval
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -507,30 +513,14 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
ic.DisableVNCApproval = &disableVNCApproval
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
applySSHFlagsToConfig(cmd, &ic)
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
@@ -606,6 +596,49 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
}
|
||||
|
||||
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ttl := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &ttl
|
||||
}
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
@@ -635,31 +668,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
loginRequest.DisableVNCApproval = &disableVNCApproval
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
applySSHFlagsToLogin(cmd, &loginRequest)
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
100
client/cmd/vnc_agent.go
Normal file
100
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build windows || (darwin && !ios)
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var (
|
||||
vncAgentSocket string
|
||||
vncAgentTargetUID uint32
|
||||
)
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
|
||||
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server inside the user's interactive session,
|
||||
// listening on a Unix-domain socket. The NetBird service spawns it: on
|
||||
// Windows via CreateProcessAsUser into the console session, on macOS via
|
||||
// launchctl asuser into the Aqua session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
if vncAgentSocket == "" {
|
||||
return fmt.Errorf("--socket is required")
|
||||
}
|
||||
|
||||
token := os.Getenv("NB_VNC_AGENT_TOKEN")
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
// Drop root privileges to the target console user BEFORE creating
|
||||
// the listening socket: keeps a post-auth bug in the encoder /
|
||||
// input / capture paths confined to the user's own privileges
|
||||
// rather than escalating to host root, and makes the daemon's
|
||||
// LOCAL_PEERCRED check see the right uid. No-op on Windows
|
||||
// (both processes run as SYSTEM) and when --target-uid is 0.
|
||||
if vncAgentTargetUID != 0 {
|
||||
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
|
||||
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
ln, err := net.Listen("unix", vncAgentSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
|
||||
}
|
||||
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
|
||||
log.Debugf("chmod %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return err
|
||||
}
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: token,
|
||||
Listener: ln,
|
||||
})
|
||||
|
||||
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
|
||||
return fmt.Errorf("start vnc server: %w", err)
|
||||
}
|
||||
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
|
||||
|
||||
<-cmd.Context().Done()
|
||||
log.Info("vnc-agent context cancelled, shutting down")
|
||||
return srv.Stop()
|
||||
},
|
||||
SilenceUsage: true,
|
||||
}
|
||||
18
client/cmd/vnc_agent_darwin.go
Normal file
18
client/cmd/vnc_agent_darwin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
|
||||
}
|
||||
return capturer, injector, nil
|
||||
}
|
||||
74
client/cmd/vnc_agent_dropprivs_darwin.go
Normal file
74
client/cmd/vnc_agent_dropprivs_darwin.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// dropAgentPrivileges drops the vnc-agent process from root (its
|
||||
// launchctl-asuser-inherited starting uid) to the target console user
|
||||
// before any other initialisation runs. Without this the agent runs as
|
||||
// root for the lifetime of the session; any post-auth memory-safety
|
||||
// issue in the capture/input/encode paths would then be a root-level
|
||||
// RCE on the host instead of a user-level one. Also makes the daemon's
|
||||
// LOCAL_PEERCRED check correctly identify the agent as the console user,
|
||||
// not as root.
|
||||
//
|
||||
// Returns an error when the agent is running as a non-root uid that
|
||||
// differs from targetUID: non-root can only setuid to itself, so a
|
||||
// mismatch here means the spawn went to the wrong session.
|
||||
func dropAgentPrivileges(targetUID uint32) error {
|
||||
if targetUID == 0 {
|
||||
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
|
||||
}
|
||||
cur := uint32(os.Getuid())
|
||||
if cur == targetUID {
|
||||
return nil
|
||||
}
|
||||
if cur != 0 {
|
||||
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
|
||||
}
|
||||
// Resolve the target user's real primary group rather than reusing
|
||||
// targetUID as the gid: a user's primary group on macOS is typically
|
||||
// staff(20), not gid==uid. Fail closed if the lookup fails.
|
||||
targetGID, err := primaryGroupID(targetUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Drop supplementary groups first: setgid alone doesn't touch the
|
||||
// auxiliary group list, leaving root's groups attached would let the
|
||||
// dropped process write to root-only group-writable files.
|
||||
if err := syscall.Setgroups([]int{}); err != nil {
|
||||
return fmt.Errorf("setgroups([]): %w", err)
|
||||
}
|
||||
if err := syscall.Setgid(targetGID); err != nil {
|
||||
return fmt.Errorf("setgid(%d): %w", targetGID, err)
|
||||
}
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid(%d): %w", targetUID, err)
|
||||
}
|
||||
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
|
||||
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// primaryGroupID resolves the real primary group id of the user with the
|
||||
// given uid. Fails closed: a lookup or parse error returns an error so the
|
||||
// caller never falls back to using uid as the gid.
|
||||
func primaryGroupID(targetUID uint32) (int, error) {
|
||||
u, err := user.LookupId(strconv.Itoa(int(targetUID)))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("look up uid %d: %w", targetUID, err)
|
||||
}
|
||||
gid, err := strconv.Atoi(u.Gid)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse gid %q for uid %d: %w", u.Gid, targetUID, err)
|
||||
}
|
||||
return gid, nil
|
||||
}
|
||||
55
client/cmd/vnc_agent_dropprivs_darwin_test.go
Normal file
55
client/cmd/vnc_agent_dropprivs_darwin_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
|
||||
// dropAgentPrivileges must never be a no-op when asked to keep the
|
||||
// agent as root (target uid 0). A future caller that passes 0 by
|
||||
// mistake would otherwise leave the post-auth attack surface running
|
||||
// with full root privileges.
|
||||
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
|
||||
err := dropAgentPrivileges(0)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal for target uid 0, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "root") {
|
||||
t.Fatalf("error should mention root, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
|
||||
// where the agent is launched by hand as the target user (no root
|
||||
// available, no setuid needed). The helper must succeed silently
|
||||
// instead of trying (and failing) a setuid to its current uid.
|
||||
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
|
||||
// Skip when running as root: the early-return path we want to
|
||||
// cover only fires when current uid == target uid.
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; cannot exercise the no-op early-return")
|
||||
}
|
||||
if err := dropAgentPrivileges(uid); err != nil {
|
||||
t.Fatalf("expected no-op when current uid == target, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
|
||||
// caller tries to setuid to a different uid" path: setuid would fail
|
||||
// with EPERM anyway, but the helper should surface a clear error
|
||||
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
|
||||
// flag) is debuggable.
|
||||
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; covered case requires non-root caller")
|
||||
}
|
||||
err := dropAgentPrivileges(uid + 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
|
||||
}
|
||||
}
|
||||
11
client/cmd/vnc_agent_dropprivs_testhelpers_darwin.go
Normal file
11
client/cmd/vnc_agent_dropprivs_testhelpers_darwin.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import "os"
|
||||
|
||||
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
|
||||
// without leaking an os import into the test file itself.
|
||||
func currentUIDForTest() uint32 {
|
||||
return uint32(os.Getuid())
|
||||
}
|
||||
14
client/cmd/vnc_agent_dropprivs_windows.go
Normal file
14
client/cmd/vnc_agent_dropprivs_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
|
||||
// both run as SYSTEM (the daemon spawns the agent into the interactive
|
||||
// session via CreateProcessAsUser with an impersonation token, but the
|
||||
// resulting process still runs under SYSTEM, not under the user's
|
||||
// account). The Windows path relies on the DACL-restricted socket
|
||||
// directory, the unpredictable per-spawn socket name, the listen-readiness
|
||||
// gate, and the per-spawn token for integrity instead.
|
||||
func dropAgentPrivileges(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
15
client/cmd/vnc_agent_windows.go
Normal file
15
client/cmd/vnc_agent_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent running in Windows session %d", sessionID)
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
|
||||
}
|
||||
16
client/cmd/vnc_flags.go
Normal file
16
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cmd
|
||||
|
||||
const (
|
||||
serverVNCAllowedFlag = "allow-server-vnc"
|
||||
disableVNCApprovalFlag = "disable-vnc-approval"
|
||||
)
|
||||
|
||||
var (
|
||||
serverVNCAllowed bool
|
||||
disableVNCApproval bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
|
||||
}
|
||||
@@ -6,19 +6,30 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var StateDir string
|
||||
var (
|
||||
// StateDir holds persistent state (config, profiles, install metadata).
|
||||
StateDir string
|
||||
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
|
||||
// such as Unix sockets for daemon and per-session IPC. Empty on
|
||||
// platforms without a conventional /var/run-style location.
|
||||
RuntimeDir string
|
||||
)
|
||||
|
||||
func init() {
|
||||
StateDir = os.Getenv("NB_STATE_DIR")
|
||||
if StateDir != "" {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
case "darwin", "linux":
|
||||
StateDir = "/var/lib/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
StateDir = "/var/db/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
}
|
||||
if v := os.Getenv("NB_STATE_DIR"); v != "" {
|
||||
StateDir = v
|
||||
}
|
||||
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
|
||||
RuntimeDir = v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
// ConnectClient.Stop now cancels its own run context and waits for the
|
||||
// run loop to tear the engine down, so this cancel() is no longer
|
||||
// required to break the deadlock and could be removed. It is kept as a
|
||||
// defensive belt-and-suspenders: cancelling the parent context first
|
||||
// guarantees the run loop is unblocked even if Stop's contract regresses.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && privileged
|
||||
|
||||
package iptables
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build !android && privileged
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || !privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package wgproxy
|
||||
|
||||
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||
wgPort := 51852
|
||||
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||
wgPort := 51850
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("192.168.0.56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||
wgPort := 51851
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
// NetBird UDP address of the remote peer
|
||||
nbAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("100.108.111.177"),
|
||||
Port: 38746,
|
||||
}
|
||||
|
||||
p2pEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP("fe80::56"),
|
||||
Port: 51820,
|
||||
}
|
||||
|
||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||
}
|
||||
|
||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||
wgPort := 51856
|
||||
|
||||
219
client/internal/approval/broker.go
Normal file
219
client/internal/approval/broker.go
Normal file
@@ -0,0 +1,219 @@
|
||||
// Package approval brokers per-attempt user-accept prompts for inbound
|
||||
// remote access (VNC today, SSH and others in the future). A caller pushes
|
||||
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
|
||||
// blocks until the UI calls the daemon's RespondApproval RPC, the per-
|
||||
// request timeout fires, or no subscriber is connected. The latter case
|
||||
// fails closed so a backgrounded UI cannot silently bypass the gate.
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
|
||||
// should not set these themselves; values in Prompt.Metadata that collide
|
||||
// are overwritten by the broker.
|
||||
const (
|
||||
MetaRequestID = "request_id"
|
||||
MetaKind = "kind"
|
||||
MetaExpiresAt = "expires_at"
|
||||
)
|
||||
|
||||
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
|
||||
// short, eyeball-able fingerprint to display in the approval dialog.
|
||||
// The dashboard-supplied display name attached to a SessionPubKey isn't
|
||||
// cryptographically asserted by the connecting client, so the prompt
|
||||
// must also show something that IS: the key fingerprint, a hash of
|
||||
// the static public key the client just proved possession of during the
|
||||
// Noise handshake. Returns the empty string when the input is too short
|
||||
// to plausibly be a hex pubkey, so the row is omitted rather than
|
||||
// rendered as a misleading partial.
|
||||
//
|
||||
// Output format: 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64 bits of
|
||||
// fingerprint, resistant to random-prefix collisions and easy for a human
|
||||
// to compare with an out-of-band reference).
|
||||
func ShortKeyFingerprint(hexKey string) string {
|
||||
if len(hexKey) < 8 {
|
||||
return ""
|
||||
}
|
||||
src := hexKey
|
||||
if len(src) > 16 {
|
||||
src = src[:16]
|
||||
}
|
||||
var out []byte
|
||||
for i, c := range src {
|
||||
if i > 0 && i%4 == 0 {
|
||||
out = append(out, '-')
|
||||
}
|
||||
out = append(out, byte(c))
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Kind values for the well-known prompt subjects. New subsystems should
|
||||
// add a constant here so the UI can dispatch on a known string.
|
||||
const (
|
||||
KindVNC = "vnc"
|
||||
KindSSH = "ssh"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the wall-clock window the user has to accept or deny a
|
||||
// pending approval before the broker fails closed and returns ErrTimeout.
|
||||
// Kept well under typical VNC client and dashboard connection timeouts so
|
||||
// the RFB rejection actually reaches the browser instead of racing the
|
||||
// browser's own "connection timed out" message.
|
||||
const DefaultTimeout = 15 * time.Second
|
||||
|
||||
// timeoutValue returns the active timeout. It's a var so tests in this
|
||||
// package can shorten the wait without exposing a setter on the public
|
||||
// API. Production code always sees DefaultTimeout.
|
||||
var timeoutValue = func() time.Duration { return DefaultTimeout }
|
||||
|
||||
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
|
||||
// The caller must reject the underlying connection (fail-closed).
|
||||
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
|
||||
|
||||
// ErrTimeout indicates the user did not respond within DefaultTimeout.
|
||||
var ErrTimeout = errors.New("approval timed out")
|
||||
|
||||
// ErrDenied indicates the user explicitly denied the connection.
|
||||
var ErrDenied = errors.New("approval denied")
|
||||
|
||||
// EventPublisher is the subset of peer.Status used to emit prompts.
|
||||
type EventPublisher interface {
|
||||
PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
HasEventSubscribers() bool
|
||||
}
|
||||
|
||||
// Prompt describes the pending request shown to the user. Kind selects
|
||||
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
|
||||
// one-liner the UI may show as a title or notification body. Metadata is
|
||||
// passed through verbatim and is the subsystem-specific payload (peer
|
||||
// name, source IP, mode, etc.).
|
||||
type Prompt struct {
|
||||
Kind string
|
||||
Subject string
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// Decision carries the user's response to an approval prompt. ViewOnly is
|
||||
// only meaningful when Accept is true; it lets the host grant the
|
||||
// connection but signal the requester that input control is withheld.
|
||||
type Decision struct {
|
||||
Accept bool
|
||||
ViewOnly bool
|
||||
}
|
||||
|
||||
// Broker holds in-flight approval requests keyed by request ID.
|
||||
type Broker struct {
|
||||
pub EventPublisher
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[string]chan Decision
|
||||
}
|
||||
|
||||
// New returns a broker that publishes prompts via pub.
|
||||
func New(pub EventPublisher) *Broker {
|
||||
return &Broker{
|
||||
pub: pub,
|
||||
pending: make(map[string]chan Decision),
|
||||
}
|
||||
}
|
||||
|
||||
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
|
||||
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
|
||||
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
|
||||
// otherwise. Callers must treat any non-nil error as a deny.
|
||||
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
|
||||
var zero Decision
|
||||
if b == nil || b.pub == nil {
|
||||
return zero, fmt.Errorf("approval broker not configured")
|
||||
}
|
||||
if !b.pub.HasEventSubscribers() {
|
||||
return zero, ErrNoSubscriber
|
||||
}
|
||||
|
||||
id := uuid.NewString()
|
||||
resp := make(chan Decision, 1)
|
||||
|
||||
b.mu.Lock()
|
||||
b.pending[id] = resp
|
||||
b.mu.Unlock()
|
||||
|
||||
defer b.dropPending(id)
|
||||
|
||||
timeout := timeoutValue()
|
||||
expiresAt := time.Now().Add(timeout)
|
||||
meta := make(map[string]string, len(p.Metadata)+3)
|
||||
for k, v := range p.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
meta[MetaRequestID] = id
|
||||
meta[MetaKind] = p.Kind
|
||||
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
|
||||
|
||||
subject := p.Subject
|
||||
if subject == "" {
|
||||
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
|
||||
}
|
||||
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
|
||||
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case d := <-resp:
|
||||
if !d.Accept {
|
||||
return zero, ErrDenied
|
||||
}
|
||||
return d, nil
|
||||
case <-timer.C:
|
||||
return zero, ErrTimeout
|
||||
case <-ctx.Done():
|
||||
return zero, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Respond delivers the user's decision for id. Returns true when a pending
|
||||
// request matched and was woken, false when id was unknown or already done.
|
||||
func (b *Broker) Respond(id string, d Decision) bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
b.mu.Lock()
|
||||
ch, ok := b.pending[id]
|
||||
if ok {
|
||||
delete(b.pending, id)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case ch <- d:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *Broker) dropPending(id string) {
|
||||
b.mu.Lock()
|
||||
delete(b.pending, id)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
434
client/internal/approval/broker_test.go
Normal file
434
client/internal/approval/broker_test.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakePublisher records published events and reports whether subscribers
|
||||
// are connected. The subscribers flag is the security-critical signal:
|
||||
// when false the broker must refuse to emit and the gate must fail closed.
|
||||
type fakePublisher struct {
|
||||
mu sync.Mutex
|
||||
subscribers bool
|
||||
events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
func (p *fakePublisher) PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
p.mu.Lock()
|
||||
p.events = append(p.events, &proto.SystemEvent{
|
||||
Severity: severity,
|
||||
Category: category,
|
||||
Message: msg,
|
||||
UserMessage: userMsg,
|
||||
Metadata: metadata,
|
||||
})
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *fakePublisher) HasEventSubscribers() bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.subscribers
|
||||
}
|
||||
|
||||
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
|
||||
t.Helper()
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
require.NotEmpty(t, p.events, "publisher saw no events")
|
||||
return p.events[len(p.events)-1]
|
||||
}
|
||||
|
||||
func (p *fakePublisher) eventCount() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.events)
|
||||
}
|
||||
|
||||
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
|
||||
// when the UI is not subscribed, the broker must refuse without emitting
|
||||
// an event or arming a waiter. A regression here is a silent bypass.
|
||||
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: false}
|
||||
b := New(pub)
|
||||
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrNoSubscriber)
|
||||
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
|
||||
|
||||
b.mu.Lock()
|
||||
pending := len(b.pending)
|
||||
b.mu.Unlock()
|
||||
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
|
||||
}
|
||||
|
||||
// TestRequestTimeoutDenies verifies that a request without a UI response
|
||||
// returns ErrTimeout (deny) rather than nil (silent accept). Uses a short
|
||||
// per-test broker timeout via Respond after the fact to keep the test fast.
|
||||
func TestRequestTimeoutDenies(t *testing.T) {
|
||||
// Replace DefaultTimeout for the lifetime of this test.
|
||||
orig := DefaultTimeout
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, orig)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
start := time.Now()
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
|
||||
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
|
||||
}
|
||||
|
||||
// TestRequestDenied returns ErrDenied when the UI responds with false.
|
||||
func TestRequestDenied(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
var requestID string
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
requestID = waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(requestID, Decision{Accept: false}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(false)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
|
||||
// gate but breaks the feature.
|
||||
func TestRequestAccepted(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(true)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
|
||||
// engine shutting down mid-prompt) returns the cancel error rather than
|
||||
// nil. A nil here would be a silent bypass on shutdown races.
|
||||
func TestRequestCtxCancelDenies(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
// Wait until the prompt is in flight so cancel races a live waiter.
|
||||
_ = waitForRequestID(t, pub)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after ctx cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
|
||||
// affect or accidentally accept any in-flight request whose id it doesn't
|
||||
// match. Also confirms it doesn't panic.
|
||||
func TestRespondUnknownIsNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
// No in-flight prompts: Respond returns false.
|
||||
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
|
||||
|
||||
// With an in-flight prompt, a wrong id still returns false and the
|
||||
// prompt remains armed (eventually timing out as a deny).
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
realID := waitForRequestID(t, pub)
|
||||
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
|
||||
assert.NotEqual(t, "totally-bogus", realID)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondAfterTimeoutNoop confirms a late accept response can't
|
||||
// retroactively flip a denied (timed-out) request. The dropPending defer
|
||||
// in Request must have removed the entry by the time Respond races in.
|
||||
func TestRespondAfterTimeoutNoop(t *testing.T) {
|
||||
defaultTimeout(t, 30*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.ErrorIs(t, err, ErrTimeout)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not time out")
|
||||
}
|
||||
|
||||
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
|
||||
}
|
||||
|
||||
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
|
||||
// past the matched waiter or panic on a closed/full channel.
|
||||
func TestRespondDoubleNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNilBrokerRequestErrors guards the engine pre-init path where the
|
||||
// broker may not yet exist (or its publisher is nil): Request must
|
||||
// error, never silently accept.
|
||||
func TestNilBrokerRequestErrors(t *testing.T) {
|
||||
var b *Broker
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "nil broker must error, never silently accept")
|
||||
|
||||
b2 := New(nil)
|
||||
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
|
||||
}
|
||||
|
||||
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
|
||||
// and expires_at on the emitted event. The UI relies on these keys; if
|
||||
// they are dropped, the user cannot route the prompt and the response
|
||||
// path breaks (which fails closed via timeout).
|
||||
func TestPromptMetadataInjected(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{
|
||||
Kind: KindVNC,
|
||||
Subject: "VNC connection from peerA",
|
||||
Metadata: map[string]string{"peer_name": "peerA"},
|
||||
})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
ev := pub.lastEvent(t)
|
||||
|
||||
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
|
||||
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
|
||||
assert.Equal(t, id, ev.Metadata[MetaRequestID])
|
||||
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
|
||||
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
|
||||
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestConcurrentRequests verifies that two concurrent prompts are tracked
|
||||
// independently. A bug that aliases ids would let one Respond unblock
|
||||
// the wrong waiter (a silent accept across prompts).
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
const n = 20
|
||||
results := make(chan error, n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
}
|
||||
|
||||
ids := waitForNRequestIDs(t, pub, n)
|
||||
require.Len(t, ids, n)
|
||||
|
||||
// Deny exactly half, accept the rest. Track outcome per id so we can
|
||||
// match each Request's return value against the response we sent.
|
||||
denySet := make(map[string]bool, n)
|
||||
for i, id := range ids {
|
||||
deny := i%2 == 0
|
||||
denySet[id] = deny
|
||||
require.True(t, b.Respond(id, Decision{Accept: !deny}))
|
||||
}
|
||||
|
||||
// Collect all returns and check no nil errors slipped past a deny.
|
||||
var accepted, denied atomic.Int32
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
if err == nil {
|
||||
accepted.Add(1)
|
||||
} else {
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
denied.Add(1)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("only got %d/%d responses", i, n)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, int32(n/2), denied.Load())
|
||||
assert.Equal(t, int32(n/2), accepted.Load())
|
||||
}
|
||||
|
||||
// waitForRequestID blocks until the publisher sees its next event and
|
||||
// returns the request_id stamped on it.
|
||||
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
var id string
|
||||
if count > 0 {
|
||||
id = pub.events[count-1].Metadata[MetaRequestID]
|
||||
}
|
||||
pub.mu.Unlock()
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timeout waiting for emitted event")
|
||||
return ""
|
||||
}
|
||||
|
||||
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
pub.mu.Unlock()
|
||||
if count >= n {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
pub.mu.Lock()
|
||||
defer pub.mu.Unlock()
|
||||
out := make([]string, 0, len(pub.events))
|
||||
seen := make(map[string]struct{}, len(pub.events))
|
||||
for _, ev := range pub.events {
|
||||
id := ev.Metadata[MetaRequestID]
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
if len(out) < n {
|
||||
t.Fatalf("only got %d/%d request ids", len(out), n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// defaultTimeout swaps the broker's per-request wall-clock window so the
|
||||
// timeout tests run quickly. Restores the prior value on the next call.
|
||||
func defaultTimeout(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
if d <= 0 {
|
||||
t.Fatal("defaultTimeout must be > 0")
|
||||
}
|
||||
timeoutValue = func() time.Duration { return d }
|
||||
}
|
||||
|
||||
// requestErr wraps Broker.Request to drop the Decision when tests only
|
||||
// care about the error path. Keeps the goroutine bodies tight.
|
||||
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
|
||||
_, err := b.Request(ctx, p)
|
||||
return err
|
||||
}
|
||||
|
||||
// TestRequestViewOnly checks the view-only outcome flows through Request's
|
||||
// Decision return without being silently swallowed.
|
||||
func TestRequestViewOnly(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
type result struct {
|
||||
d Decision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
done <- result{d, err}
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
|
||||
|
||||
select {
|
||||
case r := <-done:
|
||||
assert.NoError(t, r.err)
|
||||
assert.True(t, r.d.Accept)
|
||||
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("view-only request did not resolve")
|
||||
}
|
||||
}
|
||||
62
client/internal/approval/fingerprint_test.go
Normal file
62
client/internal/approval/fingerprint_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package approval
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestShortKeyFingerprint locks in the format the VNC approval prompt
|
||||
// shows to the user. The fingerprint is the user's only cryptographic
|
||||
// anchor against a malicious management server that pushes a spoofed
|
||||
// display name, so accidental changes to its format would silently
|
||||
// undermine that defence.
|
||||
func TestShortKeyFingerprint(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full_32_byte_pubkey",
|
||||
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "exactly_16_chars",
|
||||
in: "0123456789abcdef",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "borderline_8_chars",
|
||||
in: "01234567",
|
||||
want: "0123-4567",
|
||||
},
|
||||
{
|
||||
name: "too_short_returns_empty",
|
||||
in: "0123",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty_returns_empty",
|
||||
in: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShortKeyFingerprint(tc.in)
|
||||
if got != tc.want {
|
||||
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
|
||||
// formatting bug that would collapse different prefixes onto the same
|
||||
// displayed fingerprint and let an attacker substitute their pubkey for
|
||||
// a victim's while keeping the prompt visually identical.
|
||||
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
|
||||
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
|
||||
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
|
||||
if a == b {
|
||||
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
|
||||
}
|
||||
}
|
||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
runCancel context.CancelFunc
|
||||
runExited chan struct{}
|
||||
runOnce sync.Once
|
||||
runStarted atomic.Bool
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -70,8 +75,14 @@ func NewConnectClient(
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
) *ConnectClient {
|
||||
// Derive the run context here so Stop owns the cancel that unblocks the run
|
||||
// loop. runCancel is set once at construction, so Stop can call it without
|
||||
// racing the run loop's startup. Callers therefore need not cancel before Stop.
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
ctx: runCtx,
|
||||
runCancel: runCancel,
|
||||
runExited: make(chan struct{}),
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
@@ -135,6 +146,11 @@ func (c *ConnectClient) RunOniOS(
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
// Mark the loop as started and signal exit on return so Stop can wait for
|
||||
// the loop to finish (and skip the wait if the loop never ran).
|
||||
c.runStarted.Store(true)
|
||||
defer c.runOnce.Do(func() { close(c.runExited) })
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -290,7 +306,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
c.runCancel()
|
||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||
}
|
||||
return wrapErr(err)
|
||||
@@ -410,14 +426,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
// todo: consider to remove this condition. Is not thread safe.
|
||||
// We should always call Stop(), but we need to verify that it is idempotent
|
||||
if engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
@@ -433,12 +445,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
c.statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
state.Set(StatusNeedsLogin)
|
||||
_ = c.Stop()
|
||||
c.runCancel()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -516,11 +528,9 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
c.runCancel()
|
||||
if c.runStarted.Load() {
|
||||
<-c.runExited
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -571,6 +581,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
DisableVNCApproval: config.DisableVNCApproval,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -653,6 +665,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
|
||||
@@ -655,6 +655,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
if g.internalConfig.ServerVNCAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
|
||||
}
|
||||
if g.internalConfig.DisableVNCApproval != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -864,6 +864,8 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
ServerVNCAllowed: &bTrue,
|
||||
DisableVNCApproval: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
|
||||
@@ -51,13 +51,20 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -76,9 +83,10 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
183
client/internal/dns/mgmt/mgmt_resolve_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
|
||||
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
|
||||
// RFC 6891 a responder must not include an OPT RR toward a client that did not
|
||||
// advertise EDNS0.
|
||||
func StripOPT(msg *dns.Msg) {
|
||||
if len(msg.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := msg.Extra[:0]
|
||||
for _, rr := range msg.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
msg.Extra = out
|
||||
}
|
||||
|
||||
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
|
||||
// the message, if present.
|
||||
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
|
||||
opt := msg.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if ede, ok := o.(*dns.EDNS0_EDE); ok {
|
||||
return ede, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
StripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestExtractEDE(t *testing.T) {
|
||||
t.Run("no edns", func(t *testing.T) {
|
||||
_, ok := ExtractEDE(&dns.Msg{})
|
||||
assert.False(t, ok, "message without OPT has no EDE")
|
||||
})
|
||||
|
||||
t.Run("edns without ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
rm.SetEdns0(4096, false)
|
||||
_, ok := ExtractEDE(rm)
|
||||
assert.False(t, ok, "OPT without EDE option returns false")
|
||||
})
|
||||
|
||||
t.Run("with ede", func(t *testing.T) {
|
||||
rm := &dns.Msg{}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
|
||||
rm.Extra = append(rm.Extra, opt)
|
||||
|
||||
ede, ok := ExtractEDE(rm)
|
||||
assert.True(t, ok, "EDE option should be found")
|
||||
assert.Equal(t, uint16(49152), ede.InfoCode)
|
||||
assert.Equal(t, "upstream timeout", ede.ExtraText)
|
||||
})
|
||||
}
|
||||
|
||||
485
client/internal/dns/server_privileged_test.go
Normal file
485
client/internal/dns/server_privileged_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
//go:build privileged
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,7 +22,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -104,466 +102,6 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap []handlerWrapper
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap []handlerWrapper
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
{
|
||||
domain: nbdns.RootZone,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: "netbird.io",
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
{
|
||||
domain: "netbird.cloud",
|
||||
priority: PriorityLocal,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: nil,
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: []handlerWrapper{{
|
||||
domain: ".",
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &mockHandler{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: nil,
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = wgIface.Close()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
err = dnsServer.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||
}
|
||||
|
||||
for _, expected := range testCase.expectedUpstreamMap {
|
||||
found := false
|
||||
for _, got := range dnsServer.dnsMuxHandlers {
|
||||
if got.domain == expected.domain && got.priority == expected.priority {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||
}
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Errorf("create and init wireguard interface: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = wgIface.Close(); err != nil {
|
||||
t.Logf("close wireguard interface: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("run DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||
t.Logf("restore DNS settings on the host: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||
{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the server with regular configuration
|
||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update2 := update
|
||||
update2.ServiceEnable = false
|
||||
// Disable the server, stop the listener
|
||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
update3 := update2
|
||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||
// But service still get updates and we checking that we handle
|
||||
// internal state in the right way
|
||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -457,7 +457,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
// problems: fail over for a better answer but keep the upstream healthy.
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
@@ -466,7 +466,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
resutil.StripOPT(rm)
|
||||
}
|
||||
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
@@ -523,22 +523,6 @@ func upstreamUDPSize() uint16 {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
|
||||
@@ -985,19 +985,6 @@ func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
@@ -26,6 +26,15 @@ import (
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
// EDE info codes the forwarder emits on upstream failures so the querying
|
||||
// client can see the reason without inspecting this peer's logs. They live in
|
||||
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
|
||||
// real upstream EDE here, so these cannot collide with a genuine code.
|
||||
const (
|
||||
edeNetbirdUpstreamTimeout uint16 = 49152
|
||||
edeNetbirdUpstreamFailure uint16 = 49153
|
||||
)
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
reqHasEdns bool,
|
||||
startTime time.Time,
|
||||
) {
|
||||
qType := question.Qtype
|
||||
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
}
|
||||
|
||||
if reqHasEdns {
|
||||
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
|
||||
}
|
||||
|
||||
f.writeResponse(logger, w, resp, domain, startTime)
|
||||
}
|
||||
|
||||
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
|
||||
func edeCodeFor(dnsErr *net.DNSError) uint16 {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return edeNetbirdUpstreamTimeout
|
||||
}
|
||||
return edeNetbirdUpstreamFailure
|
||||
}
|
||||
|
||||
// edeText builds the EDE extra-text describing the class of upstream failure.
|
||||
// It deliberately omits the upstream server address, which may be an internal
|
||||
// resolver and is exposed to any client permitted to use the route; the full
|
||||
// detail stays in the forwarder's local log.
|
||||
func edeText(dnsErr *net.DNSError) string {
|
||||
if dnsErr != nil && dnsErr.IsTimeout {
|
||||
return "netbird forwarder: upstream timeout"
|
||||
}
|
||||
return "netbird forwarder: upstream failure"
|
||||
}
|
||||
|
||||
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
|
||||
// creating the OPT pseudo-record if the response does not already carry one.
|
||||
func attachEDE(resp *dns.Msg, code uint16, text string) {
|
||||
opt := resp.IsEdns0()
|
||||
if opt == nil {
|
||||
resp.SetEdns0(dns.DefaultMsgSize, false)
|
||||
opt = resp.IsEdns0()
|
||||
}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lookupErr error
|
||||
reqEdns bool
|
||||
wantEDE bool
|
||||
wantCode uint16
|
||||
wantTextHas string
|
||||
}{
|
||||
{
|
||||
name: "timeout with edns0",
|
||||
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamTimeout,
|
||||
wantTextHas: "netbird forwarder: upstream timeout",
|
||||
},
|
||||
{
|
||||
name: "server failure with edns0",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: true,
|
||||
wantEDE: true,
|
||||
wantCode: edeNetbirdUpstreamFailure,
|
||||
wantTextHas: "netbird forwarder: upstream failure",
|
||||
},
|
||||
{
|
||||
name: "no edns0 in request omits ede",
|
||||
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
|
||||
reqEdns: false,
|
||||
wantEDE: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr(nil), tt.lookupErr).Once()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
if tt.reqEdns {
|
||||
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||
mockResolver.AssertExpectations(t)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected a response")
|
||||
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
|
||||
|
||||
ede, ok := resutil.ExtractEDE(writtenResp)
|
||||
if !tt.wantEDE {
|
||||
assert.False(t, ok, "response must not carry EDE")
|
||||
return
|
||||
}
|
||||
require.True(t, ok, "response must carry EDE")
|
||||
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
|
||||
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
|
||||
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -86,6 +87,8 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
var ErrEngineAlreadyStarted = errors.New("engine already started")
|
||||
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -123,6 +126,8 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -199,6 +204,8 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
started bool
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
@@ -208,7 +215,9 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
approvalBroker *approval.Broker
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -279,9 +288,15 @@ func NewEngine(
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
// The engine is single-use: a fresh instance is built per connection
|
||||
// cycle (see Client.run), so the run context is created once here rather
|
||||
// than in Start.
|
||||
ctx, cancel := context.WithCancel(clientCtx)
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
@@ -294,6 +309,7 @@ func NewEngine(
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
approvalBroker: approval.New(services.StatusRecorder),
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
@@ -314,8 +330,34 @@ func (e *Engine) Stop() error {
|
||||
log.Debugf("tried stopping engine that is nil")
|
||||
return nil
|
||||
}
|
||||
e.cancel()
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
e.stopLocked()
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked tears down everything Start may have brought up, in the order
|
||||
// teardown requires (DNS before the interface goes down, flow manager after).
|
||||
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
|
||||
// path, so a partially-initialized engine is cleaned up the same way; every
|
||||
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
|
||||
// after releasing the lock, since the goroutines also take syncMsgMux.
|
||||
func (e *Engine) stopLocked() {
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
}
|
||||
@@ -330,6 +372,10 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -366,10 +412,6 @@ func (e *Engine) Stop() error {
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
@@ -388,21 +430,6 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
@@ -440,18 +467,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if err := iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
// The engine is single-use. Reject a duplicate start and a start on an
|
||||
// already-stopped engine (run context cancelled).
|
||||
if e.started {
|
||||
return ErrEngineAlreadyStarted
|
||||
}
|
||||
|
||||
if ctxErr := e.ctx.Err(); ctxErr != nil {
|
||||
return fmt.Errorf("engine already stopped: %w", ctxErr)
|
||||
}
|
||||
|
||||
e.started = true
|
||||
|
||||
// Tear down any partially-initialized state on a failed start. Cancel the
|
||||
// run context first so goroutines started before the failure (connMgr,
|
||||
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
|
||||
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
|
||||
// not just what close() covers.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
e.cancel()
|
||||
e.stopLocked()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = iface.ValidateMTU(e.config.MTU); err != nil {
|
||||
return fmt.Errorf("invalid MTU configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
@@ -485,13 +532,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
@@ -526,7 +571,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -535,7 +579,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -547,7 +590,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
@@ -572,9 +614,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
}
|
||||
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
if err := e.dnsServer.Initialize(); err != nil {
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
@@ -586,7 +626,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
if err = e.receiveSignalEvents(); err != nil {
|
||||
return err
|
||||
}
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
@@ -638,7 +680,6 @@ func (e *Engine) createFirewall() error {
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
@@ -1035,7 +1076,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
@@ -1044,6 +1085,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1066,6 +1108,20 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||
// can be excluded from the reported network addresses; the interface coming and
|
||||
// going otherwise churns the peer meta on the management server.
|
||||
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||
var ips []netip.Addr
|
||||
if e.config.WgAddr.IP.IsValid() {
|
||||
ips = append(ips, e.config.WgAddr.IP)
|
||||
}
|
||||
if e.config.WgAddr.HasIPv6() {
|
||||
ips = append(ips, e.config.WgAddr.IPv6)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -1091,6 +1147,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
@@ -1209,7 +1269,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
@@ -1218,6 +1278,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1407,6 +1468,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// VNC auth: always sync, including nil so cleared auth on the management
|
||||
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
|
||||
// cleanup path.
|
||||
e.updateVNCServerAuth(networkMap.GetVncAuth())
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
@@ -1698,7 +1764,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
func (e *Engine) receiveSignalEvents() error {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
@@ -1769,7 +1835,12 @@ func (e *Engine) receiveSignalEvents() {
|
||||
}
|
||||
}()
|
||||
|
||||
e.signal.WaitStreamConnected()
|
||||
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
|
||||
e.signal.WaitStreamConnected(e.ctx)
|
||||
if err := e.ctx.Err(); err != nil {
|
||||
return fmt.Errorf("wait for signal stream: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
@@ -1881,6 +1952,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -2649,3 +2721,16 @@ func decodeRelayIP(b []byte) netip.Addr {
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's decision for a pending approval to
|
||||
// the broker. viewOnly is honoured only when accept is true. Returns
|
||||
// true when the request_id matched a live prompt.
|
||||
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
|
||||
if e == nil || e.approvalBroker == nil {
|
||||
return false
|
||||
}
|
||||
return e.approvalBroker.Respond(requestID, approval.Decision{
|
||||
Accept: accept,
|
||||
ViewOnly: accept && viewOnly,
|
||||
})
|
||||
}
|
||||
|
||||
565
client/internal/engine_privileged_test.go
Normal file
565
client/internal/engine_privileged_test.go
Normal file
@@ -0,0 +1,565 @@
|
||||
//go:build privileged
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers were started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
@@ -237,22 +237,18 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
NetworkValidation: wgAddr,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
|
||||
@@ -6,37 +6,18 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -50,18 +31,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
@@ -69,25 +39,9 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
@@ -234,129 +188,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.21/24"},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||
},
|
||||
}
|
||||
|
||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||
networkMap := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// SSH server is enabled, therefore SSH config should be applied
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 8,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 9,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
@@ -426,7 +257,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
@@ -631,97 +462,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -817,7 +557,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1024,7 +764,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
@@ -1105,104 +845,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
// log.SetLevel(log.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
mu := sync.Mutex{}
|
||||
engines := []*Engine{}
|
||||
numPeers := 10
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(numPeers)
|
||||
// create and start peers
|
||||
for i := 0; i < numPeers; i++ {
|
||||
j := i
|
||||
go func() {
|
||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
engines = append(engines, engine)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until all have been created and started
|
||||
wg.Wait()
|
||||
if len(engines) != numPeers {
|
||||
t.Fatal("not all peers was started")
|
||||
}
|
||||
// check whether all the peer have expected peers connected
|
||||
|
||||
expectedConnected := numPeers * (numPeers - 1)
|
||||
|
||||
// adjust according to timeouts
|
||||
timeout := 50 * time.Second
|
||||
timeoutChan := time.After(timeout)
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
totalConnected := 0
|
||||
for _, engine := range engines {
|
||||
totalConnected += getConnectedPeers(engine)
|
||||
}
|
||||
if totalConnected == expectedConnected {
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
break loop
|
||||
}
|
||||
log.Infof("total connected=%d", totalConnected)
|
||||
}
|
||||
}
|
||||
// cleanup test
|
||||
for n, peerEngine := range engines {
|
||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||
errStop := peerEngine.mgmClient.Close()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||
}
|
||||
errStop = peerEngine.Stop()
|
||||
if errStop != nil {
|
||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
ifaceList, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -1526,187 +1168,6 @@ func TestCompareNetIPLists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifaceName string
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
conf := &EngineConfig{
|
||||
WgIfaceName: ifaceName,
|
||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||
WgPrivateKey: key,
|
||||
WgPort: wgPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.IsConnected() {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
|
||||
302
client/internal/engine_vnc.go
Normal file
302
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,302 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/vnc"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
ActiveSessions() []vncserver.ActiveSessionInfo
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector, ok := newPlatformVNC()
|
||||
if !ok {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
var sessionRecorder func(vncserver.SessionTick)
|
||||
if e.clientMetrics != nil {
|
||||
sessionRecorder = func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
Writes: t.Writes,
|
||||
FBUs: t.FBUs,
|
||||
MaxFBUBytes: t.MaxFBUBytes,
|
||||
MaxFBURects: t.MaxFBURects,
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
}
|
||||
}
|
||||
serviceMode := vncNeedsServiceMode()
|
||||
if serviceMode {
|
||||
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
|
||||
}
|
||||
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
IdentityKey: e.config.WgPrivateKey[:],
|
||||
ServiceMode: serviceMode,
|
||||
SessionRecorder: sessionRecorder,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
RequireApproval: requireApproval,
|
||||
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
|
||||
})
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
// A nil vncAuth clears all authorized users and session pubkeys so management
|
||||
// can revoke access by omitting the field on the next sync.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if vncAuth == nil {
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{})
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, pk := range vncAuth.GetSessionPubKeys() {
|
||||
pub := pk.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := pk.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
DisplayName: pk.GetDisplayName(),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running and the list
|
||||
// of active VNC sessions. The pointer is captured under syncMsgMux so a
|
||||
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
|
||||
// check and the ActiveSessions call.
|
||||
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
vncSrv := e.vncSrv
|
||||
e.syncMsgMux.Unlock()
|
||||
if vncSrv == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, vncSrv.ActiveSessions()
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if e.wgInterface != nil && e.wgInterface.GetNet() != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// vncApprover adapts the generic approval.Broker for the VNC server.
|
||||
type vncApprover struct {
|
||||
broker *approval.Broker
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
|
||||
// Resolve the source overlay IP to a peer FQDN for the prompt label.
|
||||
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
|
||||
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
|
||||
info.PeerName = fqdn
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
|
||||
meta := map[string]string{
|
||||
"peer_name": info.PeerName,
|
||||
"peer_pubkey": info.PeerPubKey,
|
||||
"source_ip": info.SourceIP,
|
||||
"mode": info.Mode,
|
||||
"username": info.Username,
|
||||
"initiator": info.Initiator,
|
||||
}
|
||||
d, err := a.broker.Request(ctx, approval.Prompt{
|
||||
Kind: approval.KindVNC,
|
||||
Subject: subject,
|
||||
Metadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return vncserver.ApprovalDecision{}, err
|
||||
}
|
||||
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
|
||||
}
|
||||
|
||||
func displayPeer(info vncserver.ApprovalInfo) string {
|
||||
if info.Initiator != "" {
|
||||
return info.Initiator
|
||||
}
|
||||
if info.PeerName != "" {
|
||||
return info.PeerName
|
||||
}
|
||||
if info.SourceIP != "" {
|
||||
return info.SourceIP
|
||||
}
|
||||
if info.PeerPubKey != "" {
|
||||
return info.PeerPubKey
|
||||
}
|
||||
return "unknown peer"
|
||||
}
|
||||
31
client/internal/engine_vnc_console_freebsd.go
Normal file
31
client/internal/engine_vnc_console_freebsd.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
|
||||
// for capture, /dev/uinput for input. The uinput device requires the
|
||||
// `uinput` kernel module (`kldload uinput`); without it, input init
|
||||
// fails and we drop to a stub injector so the user still gets a
|
||||
// view-only screen mirror.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
|
||||
}
|
||||
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
|
||||
return poller, inj, nil
|
||||
} else {
|
||||
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
}
|
||||
30
client/internal/engine_vnc_console_linux.go
Normal file
30
client/internal/engine_vnc_console_linux.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
|
||||
// without a running X server. Used as the auto-fallback when
|
||||
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
|
||||
// /dev/uinput aren't usable so the caller can drop back to a stub.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
|
||||
}
|
||||
inj, err := vncserver.NewUInputInjector(w, h)
|
||||
if err != nil {
|
||||
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
return poller, inj, nil
|
||||
}
|
||||
34
client/internal/engine_vnc_darwin.go
Normal file
34
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
// Prompt for Screen Recording at server-enable time rather than first
|
||||
// client-connect. The native prompt is far easier for users to act on
|
||||
// in the moment they toggled VNC on than later when "the screen looks
|
||||
// like wallpaper" would otherwise be the only clue.
|
||||
vncserver.PrimeScreenCapturePermission()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}, true
|
||||
}
|
||||
return capturer, injector, true
|
||||
}
|
||||
|
||||
// vncNeedsServiceMode reports whether the running process is a system
|
||||
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
|
||||
// bootstrap namespace and cannot talk to WindowServer; we route capture
|
||||
// through a per-user agent in that case.
|
||||
func vncNeedsServiceMode() bool {
|
||||
return os.Geteuid() == 0 && os.Getppid() == 1
|
||||
}
|
||||
23
client/internal/engine_vnc_stub.go
Normal file
23
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build js || ios || android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
|
||||
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error { return nil }
|
||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
35
client/internal/engine_vnc_x11.go
Normal file
35
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
|
||||
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
|
||||
injector, err := vncserver.NewX11InputInjector("", "", "")
|
||||
if err == nil {
|
||||
return vncserver.NewX11Poller("", ""), injector, true
|
||||
}
|
||||
log.Debugf("VNC: X11 not available: %v", err)
|
||||
|
||||
// Fallback for headless / pre-X states (kernel console, login manager
|
||||
// without X, physical server in recovery): stream the framebuffer and
|
||||
// inject input via /dev/uinput.
|
||||
consoleCap, consoleInj, err := newConsoleVNC()
|
||||
if err == nil {
|
||||
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
|
||||
return consoleCap, consoleInj, true
|
||||
}
|
||||
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
|
||||
|
||||
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -119,10 +119,6 @@ func (d *BindListener) ReadPackets() {
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.lazyConn.Close()
|
||||
d.bind.RemoveEndpoint(d.fakeIP)
|
||||
d.done.Done()
|
||||
|
||||
@@ -120,6 +120,36 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_vnc_traffic",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"period_seconds": tick.Period.Seconds(),
|
||||
"bytes_out": float64(tick.BytesOut),
|
||||
"writes": float64(tick.Writes),
|
||||
"fbus": float64(tick.FBUs),
|
||||
"max_fbu_bytes": float64(tick.MaxFBUBytes),
|
||||
"max_fbu_rects": float64(tick.MaxFBURects),
|
||||
"max_write_bytes": float64(tick.MaxWriteBytes),
|
||||
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -59,6 +59,11 @@ type metricsImplementation interface {
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC
|
||||
// session's wire activity. Called once per metricsConn tick interval
|
||||
// (and once at session close), only when the tick saw activity.
|
||||
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
|
||||
|
||||
// Export exports metrics in InfluxDB line protocol format
|
||||
Export(w io.Writer) error
|
||||
|
||||
@@ -78,6 +83,21 @@ type ClientMetrics struct {
|
||||
pushCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
|
||||
// tick; Max* fields are the high-water marks observed during the tick.
|
||||
// Period is the wall-clock duration the deltas cover.
|
||||
type VNCSessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
|
||||
@@ -127,6 +147,17 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
|
||||
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -73,6 +73,9 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) Export(w io.Writer) error {
|
||||
if m.exportData != "" {
|
||||
_, err := w.Write([]byte(m.exportData))
|
||||
|
||||
@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
|
||||
}
|
||||
|
||||
offer := h.buildOfferAnswer()
|
||||
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
|
||||
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||
}
|
||||
|
||||
func (h *Handshaker) sendAnswer() error {
|
||||
answer := h.buildOfferAnswer()
|
||||
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
|
||||
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
|
||||
|
||||
return h.signaler.SignalAnswer(answer, h.config.Key)
|
||||
}
|
||||
|
||||
@@ -192,6 +192,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
muxRelays sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
@@ -244,8 +245,8 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
}
|
||||
|
||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.relayMgr = manager
|
||||
}
|
||||
|
||||
@@ -906,8 +907,8 @@ func (d *Status) MarkSignalConnected() {
|
||||
}
|
||||
|
||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.muxRelays.Lock()
|
||||
defer d.muxRelays.Unlock()
|
||||
d.relayStates = relayResults
|
||||
}
|
||||
|
||||
@@ -1018,24 +1019,26 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
d.muxRelays.RLock()
|
||||
if d.relayMgr == nil {
|
||||
return d.relayStates
|
||||
defer d.muxRelays.RUnlock()
|
||||
return slices.Clone(d.relayStates)
|
||||
}
|
||||
|
||||
relayMgr := d.relayMgr
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
d.muxRelays.RUnlock()
|
||||
|
||||
states := d.relayMgr.RelayStates()
|
||||
states := relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
for _, r := range relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
@@ -1238,6 +1241,15 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
}
|
||||
}
|
||||
|
||||
// HasEventSubscribers reports whether any client is currently subscribed
|
||||
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
|
||||
// fail closed when no UI is connected to prompt the user.
|
||||
func (d *Status) HasEventSubscribers() bool {
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
return len(d.eventStreams) > 0
|
||||
}
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
if sub == nil {
|
||||
|
||||
@@ -70,6 +70,8 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -125,6 +127,8 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -433,7 +437,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
||||
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
|
||||
if *input.ServerSSHAllowed {
|
||||
log.Infof("enabling SSH server")
|
||||
} else {
|
||||
@@ -454,6 +458,33 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.False()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableVNCApproval != nil {
|
||||
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
|
||||
if *input.DisableVNCApproval {
|
||||
log.Infof("disabling VNC connection approval prompt")
|
||||
} else {
|
||||
log.Infof("enabling VNC connection approval prompt")
|
||||
}
|
||||
config.DisableVNCApproval = input.DisableVNCApproval
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
|
||||
@@ -242,6 +242,35 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
|
||||
// Configs written before ServerSSHAllowed was introduced lack the field and
|
||||
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
|
||||
// apply the value instead of panicking on a nil pointer dereference.
|
||||
tests := []struct {
|
||||
name string
|
||||
input *bool
|
||||
want bool
|
||||
}{
|
||||
{"enable", util.True(), true},
|
||||
{"disable", util.False(), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
|
||||
|
||||
config, err := UpdateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
ServerSSHAllowed: tt.input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
|
||||
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
|
||||
@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
|
||||
// describing why a lookup failed. The OPT is stripped from the reply when
|
||||
// the original client did not request EDNS0.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
if !hadEdns {
|
||||
r.SetEdns0(dns.DefaultMsgSize, false)
|
||||
}
|
||||
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if ede, ok := resutil.ExtractEDE(reply); ok {
|
||||
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
|
||||
}
|
||||
if !hadEdns {
|
||||
resutil.StripOPT(reply)
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
|
||||
reply.Id = r.Id
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build privileged
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
@@ -3,79 +3,24 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "utun100"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "lo0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "lo0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := New(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -122,122 +67,3 @@ func TestBits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := New(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !android && !ios
|
||||
//go:build !android && !ios && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -26,11 +26,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -515,125 +510,3 @@ func setupTestEnv(t *testing.T) {
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,10 @@
|
||||
//go:build !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@@ -18,10 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
@@ -33,62 +26,6 @@ func init() {
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package systemops
|
||||
|
||||
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||
// privileged build tag) so the non-privileged test files in this package compile.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedExternalInt = "dummyext0"
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var expectedInternalInt = "dummyint0"
|
||||
@@ -0,0 +1,83 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||
// BSD/darwin test files compile without the privileged build tag.
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
@@ -20,63 +20,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
nbnet.Init()
|
||||
for _, tc := range testCases {
|
||||
@@ -102,16 +45,6 @@ func TestRouting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build windows && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux && !android && privileged
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||
|
||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||
// without v6 connectivity. If a default already exists it is left alone.
|
||||
//
|
||||
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
@@ -36,6 +36,7 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use a cancellable context so Stop() can abort an in-progress interactive
|
||||
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
|
||||
// bound to a port) until the OAuth callback arrives or the flow expires;
|
||||
// cancelling the context unblocks WaitToken, which then shuts that server down
|
||||
// and frees the port for the next login attempt. iOS runs login in the main-app
|
||||
// process (decoupled from the network extension), so without this the server
|
||||
// lingers after the user dismisses the browser and the next connect stalls
|
||||
// trying to bind the same port.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &Auth{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: cfg,
|
||||
cfgPath: cfgPath,
|
||||
}, nil
|
||||
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
|
||||
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
|
||||
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
|
||||
// safe to call when no login is running.
|
||||
func (a *Auth) Stop() {
|
||||
if a.cancel != nil {
|
||||
a.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -121,6 +121,14 @@ service DaemonService {
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
|
||||
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -207,6 +215,10 @@ message LoginRequest {
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
optional bool disable_ipv6 = 40;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
|
||||
optional bool disableVNCApproval = 42;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -317,12 +329,16 @@ message GetConfigResponse {
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
|
||||
bool disableVNCApproval = 29;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 28;
|
||||
repeated string mDMManagedFields = 30;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -407,6 +423,25 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCSessionInfo contains information about an active VNC session
|
||||
message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
// initiator is the human-readable display name of the dashboard user
|
||||
// who minted the SessionPubKey, when known.
|
||||
string initiator = 5;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
repeated VNCSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -421,6 +456,7 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -609,6 +645,7 @@ message SystemEvent {
|
||||
AUTHENTICATION = 2;
|
||||
CONNECTIVITY = 3;
|
||||
SYSTEM = 4;
|
||||
APPROVAL = 5;
|
||||
}
|
||||
|
||||
string id = 1;
|
||||
@@ -699,6 +736,10 @@ message SetConfigRequest {
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
optional bool disable_ipv6 = 35;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
|
||||
optional bool disableVNCApproval = 37;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -929,3 +970,18 @@ message StartBundleCaptureRequest {
|
||||
message StartBundleCaptureResponse {}
|
||||
message StopBundleCaptureRequest {}
|
||||
message StopBundleCaptureResponse {}
|
||||
|
||||
message RespondApprovalRequest {
|
||||
// request_id matches the SystemEvent metadata key emitted by the daemon
|
||||
// when a subsystem awaits user approval for an inbound connection.
|
||||
string request_id = 1;
|
||||
// accept is true if the user approved the request, false if they
|
||||
// denied it. A missing or unknown request_id is treated as a no-op.
|
||||
bool accept = 2;
|
||||
// view_only signals that the user granted the connection but withheld
|
||||
// input control. Only meaningful when accept is true; ignored when
|
||||
// accept is false.
|
||||
bool view_only = 3;
|
||||
}
|
||||
|
||||
message RespondApprovalResponse {}
|
||||
|
||||
@@ -59,6 +59,7 @@ const (
|
||||
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
|
||||
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
|
||||
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
|
||||
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
|
||||
)
|
||||
|
||||
// DaemonServiceClient is the client API for DaemonService service.
|
||||
@@ -136,6 +137,13 @@ type DaemonServiceClient interface {
|
||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -573,6 +581,16 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
|
||||
|
||||
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RespondApprovalResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility.
|
||||
@@ -648,6 +666,13 @@ type DaemonServiceServer interface {
|
||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -778,6 +803,9 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
|
||||
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
|
||||
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
@@ -1498,6 +1526,24 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
|
||||
|
||||
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RespondApprovalRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_RespondApproval_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1653,6 +1699,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetInstallerResult",
|
||||
Handler: _DaemonService_GetInstallerResult_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RespondApproval",
|
||||
Handler: _DaemonService_RespondApproval_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
|
||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
engine, err := s.claimCapture(sess)
|
||||
engine, err := s.claimCapture(sess, func() { pw.Close() })
|
||||
if err != nil {
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
@@ -190,10 +190,7 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
s.cleanupBundleCapture()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
s.evictActiveCaptureLocked()
|
||||
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
@@ -304,29 +301,58 @@ func (s *Server) cleanupBundleCapture() {
|
||||
s.bundleCapture = nil
|
||||
}
|
||||
|
||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
||||
// FailedPrecondition if another capture is already active.
|
||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
||||
// claimCapture reserves the engine's capture slot for sess. If another
|
||||
// capture is already running it is evicted: a previous streaming session
|
||||
// whose gRPC client died and never freed the slot stays stuck otherwise,
|
||||
// and a bundle capture is just informational state.
|
||||
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
s.evictActiveCaptureLocked()
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.activeCapture = sess
|
||||
s.activeCaptureCancel = cancel
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// evictActiveCaptureLocked tears down whatever capture currently owns
|
||||
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
|
||||
func (s *Server) evictActiveCaptureLocked() {
|
||||
if s.activeCapture == nil {
|
||||
return
|
||||
}
|
||||
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
|
||||
log.Infof("evicting running bundle capture to start a new capture")
|
||||
s.stopBundleCaptureLocked()
|
||||
return
|
||||
}
|
||||
log.Infof("evicting previous streaming capture to start a new one")
|
||||
prev := s.activeCapture
|
||||
cancel := s.activeCaptureCancel
|
||||
if engine, err := s.getCaptureEngineLocked(); err == nil {
|
||||
if err := engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear previous capture: %v", err)
|
||||
}
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
prev.Stop()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture == sess {
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +367,7 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
log.Debugf("clear capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
|
||||
@@ -100,8 +100,12 @@ type Server struct {
|
||||
captureEnabled bool
|
||||
bundleCapture *bundleCapture
|
||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
||||
activeCapture *capture.Session
|
||||
networksDisabled bool
|
||||
activeCapture *capture.Session
|
||||
// activeCaptureCancel tears down the streaming pipe/cancel for the
|
||||
// active streaming capture so eviction unblocks the StartCapture RPC
|
||||
// handler. Nil for bundle captures (they own their own context).
|
||||
activeCaptureCancel func()
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
@@ -456,6 +460,8 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.DisableVNCApproval = msg.DisableVNCApproval
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -993,6 +999,10 @@ func (s *Server) cleanupConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
|
||||
// actCancel() lets the run loop stop the engine too, so both stop it
|
||||
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
|
||||
// making the run loop the sole owner of engine shutdown.
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return err
|
||||
@@ -1247,6 +1257,7 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1286,6 +1297,38 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetVNCServerStatus()
|
||||
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
UserID: sess.UserID,
|
||||
Initiator: sess.Initiator,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
Enabled: enabled,
|
||||
Sessions: pbSessions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1526,6 +1569,27 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's accept/deny decision for a pending
|
||||
// approval prompt to the engine's broker. Unknown or already-resolved
|
||||
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
|
||||
// user already handled (or that already timed out).
|
||||
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
if connectClient == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
|
||||
}
|
||||
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
|
||||
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
|
||||
}
|
||||
return &proto.RespondApprovalResponse{}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
@@ -1641,6 +1705,8 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
|
||||
235
client/server/server_privileged_test.go
Normal file
235
client/server/server_privileged_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
//go:build privileged
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
@@ -2,124 +2,22 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
|
||||
counter := 0
|
||||
// start the management server
|
||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start management server: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||
defer cancel()
|
||||
// create new server
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: "http://" + mgmtAddr,
|
||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||
t.Setenv(retryInitialIntervalVar, "1s")
|
||||
t.Setenv(maxRetryIntervalVar, "2s")
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Up(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
@@ -259,119 +157,3 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mgmtProto.ManagementServiceServer
|
||||
counter *int
|
||||
}
|
||||
|
||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
*m.counter++
|
||||
return m.ManagementServiceServer.Login(ctx, req)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
|
||||
config := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mock := &mockServer{
|
||||
ManagementServiceServer: mgmtServer,
|
||||
counter: counter,
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user