mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-03 05:09:54 +00:00
Compare commits
42 Commits
client-jso
...
backport/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
415eb59ef5 | ||
|
|
844881f5cf | ||
|
|
7d4736de55 | ||
|
|
06839a4731 | ||
|
|
eb422a5cd3 | ||
|
|
0aa0f7c76b | ||
|
|
7c0d8cbae0 | ||
|
|
2ab99eefa6 | ||
|
|
ff04ffb534 | ||
|
|
980598ed4a | ||
|
|
92a66cdd20 | ||
|
|
3be90f06b2 | ||
|
|
4ef65294e9 | ||
|
|
5b5f11740a | ||
|
|
3de889d529 | ||
|
|
04c3d19032 | ||
|
|
3f1fb3b52d | ||
|
|
b434cda062 | ||
|
|
0b594c639a | ||
|
|
deff8af59f | ||
|
|
5711f0e38c | ||
|
|
1409a1325a | ||
|
|
4400372f37 | ||
|
|
2d7b309004 | ||
|
|
5968cff242 | ||
|
|
cf43841b86 | ||
|
|
739e36a313 | ||
|
|
2bb5421631 | ||
|
|
998ade6e6d | ||
|
|
62f5467cd8 | ||
|
|
1b29995ece | ||
|
|
fd96b8c12f | ||
|
|
6dd6c3f398 | ||
|
|
d1422dcf09 | ||
|
|
615631567a | ||
|
|
f4daf59bcd | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
a400828b89 | ||
|
|
e2bb328a34 | ||
|
|
221b9c012c |
68
.github/workflows/agent-network-e2e.yml
vendored
Normal file
68
.github/workflows/agent-network-e2e.yml
vendored
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
name: Agent Network E2E
|
||||||
|
|
||||||
|
on:
|
||||||
|
# Nightly at 03:00 UTC, plus on demand from the Actions tab.
|
||||||
|
schedule:
|
||||||
|
- cron: "0 3 * * *"
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
e2e:
|
||||||
|
name: Agent Network E2E
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 45
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
|
# Container-driver builder so the harness can build the combined/proxy/
|
||||||
|
# client images from source with a local layer cache.
|
||||||
|
- name: Set up Buildx
|
||||||
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||||
|
|
||||||
|
# Persist the Docker layer cache across runs. This caches the base, apt,
|
||||||
|
# and go-mod-download layers; the Go compile still re-runs, as BuildKit
|
||||||
|
# mount caches cannot be exported to the GitHub cache.
|
||||||
|
- name: Cache Docker layers
|
||||||
|
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||||
|
with:
|
||||||
|
path: /tmp/.buildx-cache
|
||||||
|
key: ${{ runner.os }}-anet-e2e-buildx-${{ hashFiles('go.sum', 'combined/Dockerfile.multistage', 'proxy/Dockerfile.multistage', 'e2e/harness/Dockerfile.client') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-anet-e2e-buildx-
|
||||||
|
|
||||||
|
- name: Run agent-network e2e
|
||||||
|
env:
|
||||||
|
# Build the images from source (this branch's code) with the shared
|
||||||
|
# local layer cache.
|
||||||
|
NB_E2E_BUILDX_CACHE: /tmp/.buildx-cache
|
||||||
|
# Provider credentials. Each provider scenario skips if its
|
||||||
|
# token (and URL, for gateways) is unset, so partial coverage is fine.
|
||||||
|
OPENAI_TOKEN: ${{ secrets.E2E_OPENAI_TOKEN }}
|
||||||
|
ANTHROPIC_TOKEN: ${{ secrets.E2E_ANTHROPIC_TOKEN }}
|
||||||
|
VERCEL_URL: ${{ secrets.E2E_VERCEL_URL }}
|
||||||
|
VERCEL_TOKEN: ${{ secrets.E2E_VERCEL_TOKEN }}
|
||||||
|
OPENROUTER_URL: ${{ secrets.E2E_OPENROUTER_URL }}
|
||||||
|
OPENROUTER_TOKEN: ${{ secrets.E2E_OPENROUTER_TOKEN }}
|
||||||
|
CLOUDFLARE_URL: ${{ secrets.E2E_CLOUDFLARE_URL }}
|
||||||
|
CLOUDFLARE_TOKEN: ${{ secrets.E2E_CLOUDFLARE_TOKEN }}
|
||||||
|
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.E2E_AWS_BEARER_TOKEN_BEDROCK }}
|
||||||
|
AWS_REGION: ${{ secrets.E2E_AWS_REGION }}
|
||||||
|
# Vertex (Anthropic-on-Vertex): SA + project required; region defaults
|
||||||
|
# to "global", model to a pinned claude snapshot.
|
||||||
|
GOOGLE_VERTEX_SA_BASE64: ${{ secrets.E2E_GOOGLE_VERTEX_SA_BASE64 }}
|
||||||
|
GOOGLE_VERTEX_PROJECT: ${{ secrets.E2E_GOOGLE_VERTEX_PROJECT }}
|
||||||
|
GOOGLE_VERTEX_REGION: ${{ secrets.E2E_GOOGLE_VERTEX_REGION }}
|
||||||
|
GOOGLE_VERTEX_MODEL: ${{ secrets.E2E_GOOGLE_VERTEX_MODEL }}
|
||||||
|
run: go test -tags e2e -timeout 40m -v ./e2e/...
|
||||||
@@ -64,7 +64,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: true
|
cache: true
|
||||||
|
|||||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -21,13 +21,13 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
|
||||||
|
|||||||
20
.github/workflows/golang-test-freebsd.yml
vendored
20
.github/workflows/golang-test-freebsd.yml
vendored
@@ -48,14 +48,14 @@ jobs:
|
|||||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -tags privileged -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -v -p 1 ./client/...
|
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -tags privileged -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -tags privileged -timeout 1m -failfast ./encryption/...
|
||||||
time go test -timeout 1m -failfast ./formatter/...
|
time go test -tags privileged -timeout 1m -failfast ./formatter/...
|
||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -tags privileged -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -tags privileged -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -tags privileged -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
50
.github/workflows/golang-test-linux.yml
vendored
50
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- 'management/**'
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -41,7 +41,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -135,7 +135,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -158,7 +158,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
@@ -180,7 +180,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -192,7 +192,7 @@ jobs:
|
|||||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
id: cache-restore
|
id: cache-restore
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -229,7 +229,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -251,7 +251,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -266,7 +266,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -311,7 +311,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -325,7 +325,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -368,7 +368,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -383,7 +383,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -429,7 +429,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -440,7 +440,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -534,7 +534,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -545,7 +545,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -579,10 +579,11 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
@@ -628,7 +629,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -639,7 +640,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -673,12 +674,13 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
GIT_BRANCH=${{ github.ref_name }} \
|
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/server/http/...
|
-timeout 20m ./management/server/http/...
|
||||||
|
env:
|
||||||
|
GIT_BRANCH: ${{ github.ref_name }}
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
name: "Management / Integration"
|
name: "Management / Integration"
|
||||||
@@ -697,7 +699,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -708,7 +710,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
|
|||||||
6
.github/workflows/golang-test-windows.yml
vendored
6
.github/workflows/golang-test-windows.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
@@ -35,7 +35,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||||
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||||
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||||
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals,flate,recordin,unparseable
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
@@ -28,13 +28,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
cmdline-tools-version: 8512546
|
cmdline-tools-version: 8512546
|
||||||
- name: Setup Java
|
- name: Setup Java
|
||||||
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
|
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
|
||||||
with:
|
with:
|
||||||
java-version: "11"
|
java-version: "11"
|
||||||
distribution: "adopt"
|
distribution: "adopt"
|
||||||
- name: NDK Cache
|
- name: NDK Cache
|
||||||
id: ndk-cache
|
id: ndk-cache
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: /usr/local/lib/android/sdk/ndk
|
path: /usr/local/lib/android/sdk/ndk
|
||||||
key: ndk-cache-23.1.7779620
|
key: ndk-cache-23.1.7779620
|
||||||
@@ -58,7 +58,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
|
|||||||
12
.github/workflows/release.yml
vendored
12
.github/workflows/release.yml
vendored
@@ -166,12 +166,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -374,12 +374,12 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -469,12 +469,12 @@ jobs:
|
|||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ jobs:
|
|||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
|||||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@@ -48,7 +48,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
|
|||||||
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: lint lint-all lint-install setup-hooks
|
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
|
||||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||||
|
|
||||||
# Install golangci-lint locally if needed
|
# Install golangci-lint locally if needed
|
||||||
@@ -25,3 +25,15 @@ setup-hooks:
|
|||||||
@git config core.hooksPath .githooks
|
@git config core.hooksPath .githooks
|
||||||
@chmod +x .githooks/pre-push
|
@chmod +x .githooks/pre-push
|
||||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||||
|
|
||||||
|
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
|
||||||
|
# Runs as a normal user with no sudo and leaves host networking untouched.
|
||||||
|
test-unit:
|
||||||
|
@go test -tags devcert -timeout 10m ./...
|
||||||
|
|
||||||
|
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
|
||||||
|
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
|
||||||
|
# Narrow the run with env vars, e.g.:
|
||||||
|
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
|
||||||
|
test-privileged:
|
||||||
|
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...
|
||||||
|
|||||||
@@ -33,10 +33,15 @@
|
|||||||
<br/>
|
<br/>
|
||||||
<br/>
|
<br/>
|
||||||
<strong>
|
<strong>
|
||||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
|
||||||
</strong>
|
</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
> ### 🤖 NetBird Agent Network (Beta)
|
||||||
|
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
|
||||||
|
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
|
||||||
|
> read the docs at **[netbird.ai](https://netbird.ai)**.
|
||||||
|
|
||||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||||
|
|
||||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||||
|
|||||||
39
agent-network/README.md
Normal file
39
agent-network/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# NetBird Agent Network
|
||||||
|
|
||||||
|
Agent Network is NetBird's access control layer for AI agents and the people who run
|
||||||
|
them. It gives every agent a real identity, tied to your identity provider (IdP), and
|
||||||
|
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
|
||||||
|
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
|
||||||
|
policy, with no API keys to leak.
|
||||||
|
|
||||||
|
> **Beta.** Agent Network is open source and can be self-hosted on your own
|
||||||
|
> infrastructure.
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
Agent Network is built on two existing NetBird capabilities:
|
||||||
|
|
||||||
|
- **Overlay network** — the encrypted WireGuard mesh between peers.
|
||||||
|
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
|
||||||
|
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
|
||||||
|
key server-side, forwards to the API or gateway, and records usage.
|
||||||
|
|
||||||
|
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
|
||||||
|
resources (databases, internal APIs, self-hosted models) are reached directly over
|
||||||
|
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
|
||||||
|
|
||||||
|
## Where the code lives
|
||||||
|
|
||||||
|
There is no separate "agent-network" service — it reuses the reverse-proxy and management
|
||||||
|
components:
|
||||||
|
|
||||||
|
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
|
||||||
|
and runs the per-request middleware pipeline.
|
||||||
|
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
|
||||||
|
— the management-side control plane: providers, policies, guardrails, limits, routing,
|
||||||
|
and usage/access logs.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
Full documentation, architecture, and quickstart:
|
||||||
|
**https://docs.netbird.io/agent-network**
|
||||||
196
client/cmd/service_privileged_test.go
Normal file
196
client/cmd/service_privileged_test.go
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/kardianos/service"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serviceStartTimeout = 10 * time.Second
|
||||||
|
serviceStopTimeout = 5 * time.Second
|
||||||
|
statusPollInterval = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||||
|
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer timeoutCancel()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(statusPollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||||
|
case <-ticker.C:
|
||||||
|
status, err := s.Status()
|
||||||
|
if err != nil {
|
||||||
|
// Continue polling on transient errors
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if status == expectedStatus {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceLifecycle tests the complete service lifecycle
|
||||||
|
func TestServiceLifecycle(t *testing.T) {
|
||||||
|
// TODO: Add support for Windows and macOS
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
if os.Getenv("CONTAINER") == "true" {
|
||||||
|
t.Skip("Skipping service lifecycle test in container environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
originalServiceName := serviceName
|
||||||
|
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||||
|
defer func() {
|
||||||
|
serviceName = originalServiceName
|
||||||
|
}()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||||
|
|
||||||
|
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service config: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cleanup: create service: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the subtests already cleaned up, there's nothing to do.
|
||||||
|
if _, err := s.Status(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.Stop(); err != nil {
|
||||||
|
t.Errorf("cleanup: stop service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.Uninstall(); err != nil {
|
||||||
|
t.Errorf("cleanup: uninstall service: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("Install", func(t *testing.T) {
|
||||||
|
installCmd.SetContext(ctx)
|
||||||
|
err := installCmd.RunE(installCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
status, err := s.Status()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEqual(t, service.StatusUnknown, status)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Start", func(t *testing.T) {
|
||||||
|
startCmd.SetContext(ctx)
|
||||||
|
err := startCmd.RunE(startCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Restart", func(t *testing.T) {
|
||||||
|
restartCmd.SetContext(ctx)
|
||||||
|
err := restartCmd.RunE(restartCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Reconfigure", func(t *testing.T) {
|
||||||
|
originalLogLevel := logLevel
|
||||||
|
logLevel = "debug"
|
||||||
|
defer func() {
|
||||||
|
logLevel = originalLogLevel
|
||||||
|
}()
|
||||||
|
|
||||||
|
reconfigureCmd.SetContext(ctx)
|
||||||
|
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, running)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stop", func(t *testing.T) {
|
||||||
|
stopCmd.SetContext(ctx)
|
||||||
|
err := stopCmd.RunE(stopCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, stopped)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Uninstall", func(t *testing.T) {
|
||||||
|
uninstallCmd.SetContext(ctx)
|
||||||
|
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := newSVCConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.Status()
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
serviceStartTimeout = 10 * time.Second
|
|
||||||
serviceStopTimeout = 5 * time.Second
|
|
||||||
statusPollInterval = 500 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
|
||||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer timeoutCancel()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(statusPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
|
||||||
case <-ticker.C:
|
|
||||||
status, err := s.Status()
|
|
||||||
if err != nil {
|
|
||||||
// Continue polling on transient errors
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if status == expectedStatus {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceLifecycle tests the complete service lifecycle
|
|
||||||
func TestServiceLifecycle(t *testing.T) {
|
|
||||||
// TODO: Add support for Windows and macOS
|
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
|
||||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.Getenv("CONTAINER") == "true" {
|
|
||||||
t.Skip("Skipping service lifecycle test in container environment")
|
|
||||||
}
|
|
||||||
|
|
||||||
originalServiceName := serviceName
|
|
||||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
|
||||||
defer func() {
|
|
||||||
serviceName = originalServiceName
|
|
||||||
}()
|
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
|
||||||
logLevel = "info"
|
|
||||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
|
||||||
|
|
||||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service config: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("cleanup: create service: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the subtests already cleaned up, there's nothing to do.
|
|
||||||
if _, err := s.Status(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.Stop(); err != nil {
|
|
||||||
t.Errorf("cleanup: stop service: %v", err)
|
|
||||||
}
|
|
||||||
if err := s.Uninstall(); err != nil {
|
|
||||||
t.Errorf("cleanup: uninstall service: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
t.Run("Install", func(t *testing.T) {
|
|
||||||
installCmd.SetContext(ctx)
|
|
||||||
err := installCmd.RunE(installCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
status, err := s.Status()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotEqual(t, service.StatusUnknown, status)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Start", func(t *testing.T) {
|
|
||||||
startCmd.SetContext(ctx)
|
|
||||||
err := startCmd.RunE(startCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Restart", func(t *testing.T) {
|
|
||||||
restartCmd.SetContext(ctx)
|
|
||||||
err := restartCmd.RunE(restartCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Reconfigure", func(t *testing.T) {
|
|
||||||
originalLogLevel := logLevel
|
|
||||||
logLevel = "debug"
|
|
||||||
defer func() {
|
|
||||||
logLevel = originalLogLevel
|
|
||||||
}()
|
|
||||||
|
|
||||||
reconfigureCmd.SetContext(ctx)
|
|
||||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, running)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Stop", func(t *testing.T) {
|
|
||||||
stopCmd.SetContext(ctx)
|
|
||||||
err := stopCmd.RunE(stopCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, stopped)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Uninstall", func(t *testing.T) {
|
|
||||||
uninstallCmd.SetContext(ctx)
|
|
||||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.Status()
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestServiceEnvVars tests environment variable parsing
|
// TestServiceEnvVars tests environment variable parsing
|
||||||
func TestServiceEnvVars(t *testing.T) {
|
func TestServiceEnvVars(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android
|
//go:build !android && privileged
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -136,6 +136,11 @@ func (p *ProxyBind) CloseConn() error {
|
|||||||
return p.close()
|
return p.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
|
||||||
|
func (p *ProxyBind) InjectPacket(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) close() error {
|
func (p *ProxyBind) close() error {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -219,6 +219,17 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
|||||||
p.pausedCond.L.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||||
|
func (p *ProxyWrapper) InjectPacket(b []byte) error {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return errors.New("proxy not started")
|
||||||
|
}
|
||||||
|
if _, err := p.remoteConn.Write(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
func (p *ProxyWrapper) CloseConn() error {
|
func (p *ProxyWrapper) CloseConn() error {
|
||||||
if p.cancel == nil {
|
if p.cancel == nil {
|
||||||
|
|||||||
@@ -18,4 +18,9 @@ type Proxy interface {
|
|||||||
RedirectAs(endpoint *net.UDPAddr)
|
RedirectAs(endpoint *net.UDPAddr)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
SetDisconnectListener(disconnected func())
|
SetDisconnectListener(disconnected func())
|
||||||
|
|
||||||
|
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
|
||||||
|
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
|
||||||
|
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
|
||||||
|
InjectPacket(b []byte) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux
|
//go:build !linux || !privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
|
|||||||
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
|
||||||
wgPort := 51850
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("192.168.0.56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
|
||||||
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
|
||||||
wgPort := 51851
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
|
||||||
t.Errorf("failed to free ebpf proxy: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
// NetBird UDP address of the remote peer
|
|
||||||
nbAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("100.108.111.177"),
|
|
||||||
Port: 38746,
|
|
||||||
}
|
|
||||||
|
|
||||||
p2pEndpoint := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("fe80::56"),
|
|
||||||
Port: 51820,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
|
||||||
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
func TestRedirectAs_UDP_IPv4(t *testing.T) {
|
||||||
wgPort := 51852
|
wgPort := 51852
|
||||||
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
|
||||||
|
wgPort := 51850
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("192.168.0.56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
|
||||||
|
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
|
||||||
|
wgPort := 51851
|
||||||
|
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
t.Fatalf("failed to initialize ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := ebpfProxy.Free(); err != nil {
|
||||||
|
t.Errorf("failed to free ebpf proxy: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
|
|
||||||
|
// NetBird UDP address of the remote peer
|
||||||
|
nbAddr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("100.108.111.177"),
|
||||||
|
Port: 38746,
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpoint := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("fe80::56"),
|
||||||
|
Port: 51820,
|
||||||
|
}
|
||||||
|
|
||||||
|
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
|
||||||
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
func TestRedirectAs_Multiple_Switches(t *testing.T) {
|
||||||
wgPort := 51856
|
wgPort := 51856
|
||||||
|
|||||||
@@ -147,6 +147,17 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
|||||||
p.sendPkg = p.srcFakerConn.SendPkg
|
p.sendPkg = p.srcFakerConn.SendPkg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||||
|
func (p *WGUDPProxy) InjectPacket(b []byte) error {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return errors.New("proxy not started")
|
||||||
|
}
|
||||||
|
if _, err := p.remoteConn.Write(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
func (p *WGUDPProxy) CloseConn() error {
|
func (p *WGUDPProxy) CloseConn() error {
|
||||||
if p.cancel == nil {
|
if p.cancel == nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
@@ -30,11 +31,13 @@ type Manager interface {
|
|||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
ipsetCounter int
|
ipsetCounter int
|
||||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||||
routeRules map[id.RuleID]struct{}
|
routeRules map[id.RuleID]struct{}
|
||||||
mutex sync.Mutex
|
previousConfigHash uint64
|
||||||
|
hasAppliedConfig bool
|
||||||
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
@@ -57,6 +60,23 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip the full rebuild + flush when the inputs that drive the firewall
|
||||||
|
// state are byte-for-byte identical to the last successfully applied
|
||||||
|
// update. Management re-sends the same network map far more often than it
|
||||||
|
// actually changes (account-wide updates, peer meta churn), and rebuilding
|
||||||
|
// every peer/route ACL and flushing the firewall on every such sync is the
|
||||||
|
// dominant client-side cost when nothing changed. Mirrors the same guard the
|
||||||
|
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
|
||||||
|
// consumes participate in the hash, so an unrelated map change cannot mask a
|
||||||
|
// real ACL change.
|
||||||
|
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
|
||||||
|
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
@@ -70,13 +90,49 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
if routeErr != nil {
|
||||||
|
log.Errorf("Failed to apply route ACLs: %v", routeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.firewall.Flush(); err != nil {
|
flushErr := d.firewall.Flush()
|
||||||
log.Error("failed to flush firewall rules: ", err)
|
if flushErr != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", flushErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only remember the hash once the firewall actually reflects this config.
|
||||||
|
// If applying or flushing failed, leave the previous hash untouched so the
|
||||||
|
// next (possibly identical) update is not skipped and gets a chance to
|
||||||
|
// reconcile the firewall state.
|
||||||
|
if err == nil && routeErr == nil && flushErr == nil {
|
||||||
|
d.previousConfigHash = hash
|
||||||
|
d.hasAppliedConfig = true
|
||||||
|
} else {
|
||||||
|
d.hasAppliedConfig = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
|
||||||
|
// firewall state, so an identical hash means an identical resulting ruleset.
|
||||||
|
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
|
||||||
|
return hashstructure.Hash(struct {
|
||||||
|
PeerRules []*mgmProto.FirewallRule
|
||||||
|
PeerRulesIsEmpty bool
|
||||||
|
RouteRules []*mgmProto.RouteFirewallRule
|
||||||
|
RouteRulesIsEmpty bool
|
||||||
|
DNSRouteFeatureFlag bool
|
||||||
|
}{
|
||||||
|
PeerRules: networkMap.GetFirewallRules(),
|
||||||
|
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
|
||||||
|
RouteRules: networkMap.GetRoutesFirewallRules(),
|
||||||
|
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
|
||||||
|
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
|
||||||
|
}, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -485,3 +486,149 @@ func TestPortInfoEmpty(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
|
||||||
|
// re-applied is recognized as a no-op (hash unchanged), while a real change to
|
||||||
|
// any firewall-relevant input forces a re-apply (hash changes). This is the
|
||||||
|
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
|
||||||
|
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
|
||||||
|
firstHash := acl.previousConfigHash
|
||||||
|
require.NotZero(t, firstHash)
|
||||||
|
|
||||||
|
// Re-applying the identical map must not change the recorded hash: the
|
||||||
|
// expensive rebuild path was skipped.
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, firstHash, acl.previousConfigHash,
|
||||||
|
"identical re-apply must be a no-op (hash unchanged)")
|
||||||
|
|
||||||
|
// A real change must produce a different hash and re-apply.
|
||||||
|
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.NotEqual(t, firstHash, acl.previousConfigHash,
|
||||||
|
"changing a rule's action must force a re-apply (hash changed)")
|
||||||
|
|
||||||
|
// The dnsRouteFeatureFlag also participates in the hash.
|
||||||
|
changedHash := acl.previousConfigHash
|
||||||
|
acl.ApplyFiltering(networkMap, true)
|
||||||
|
assert.NotEqual(t, changedHash, acl.previousConfigHash,
|
||||||
|
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
|
||||||
|
nm := &mgmProto.NetworkMap{
|
||||||
|
FirewallRulesIsEmpty: peerRules == 0,
|
||||||
|
RoutesFirewallRulesIsEmpty: routeRules == 0,
|
||||||
|
}
|
||||||
|
for i := range peerRules {
|
||||||
|
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
|
||||||
|
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: fmt.Sprintf("%d", 1024+i%64511),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for i := range routeRules {
|
||||||
|
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
|
||||||
|
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
|
||||||
|
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nm
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
|
||||||
|
d := &DefaultManager{}
|
||||||
|
nm := buildNetworkMap(10, 5)
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = d.firewallConfigHash(nm, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
|
||||||
|
d := &DefaultManager{}
|
||||||
|
nm := buildNetworkMap(100, 50)
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = d.firewallConfigHash(nm, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
|
||||||
|
d := &DefaultManager{}
|
||||||
|
nm := buildNetworkMap(1000, 200)
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = d.firewallConfigHash(nm, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
|
||||||
|
// inputs and order-independent for the rule slices (management does not
|
||||||
|
// guarantee rule order).
|
||||||
|
func TestFirewallConfigHashDeterministic(t *testing.T) {
|
||||||
|
d := &DefaultManager{}
|
||||||
|
|
||||||
|
nm1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
|
||||||
|
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Same rules, reversed order.
|
||||||
|
nm2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
nm1.FirewallRules[1],
|
||||||
|
nm1.FirewallRules[0],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h1, err := d.firewallConfigHash(nm1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
h2, err := d.firewallConfigHash(nm2, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
|
||||||
|
}
|
||||||
|
|||||||
@@ -314,6 +314,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
||||||
c.statusRecorder.MarkManagementConnected()
|
c.statusRecorder.MarkManagementConnected()
|
||||||
|
|
||||||
|
if metricsConfig := loginResp.GetNetbirdConfig().GetMetrics(); metricsConfig != nil {
|
||||||
|
c.clientMetrics.UpdatePushFromMgm(c.ctx, metricsConfig.GetEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
@@ -399,6 +403,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
StateManager: stateManager,
|
StateManager: stateManager,
|
||||||
UpdateManager: c.updateManager,
|
UpdateManager: c.updateManager,
|
||||||
ClientMetrics: c.clientMetrics,
|
ClientMetrics: c.clientMetrics,
|
||||||
|
MetricsCtx: c.ctx,
|
||||||
}, mobileDependency)
|
}, mobileDependency)
|
||||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
c.engine = engine
|
c.engine = engine
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -167,7 +168,10 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
|
|||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
alternativeNetwork = "ip6"
|
alternativeNetwork = "ip6"
|
||||||
default:
|
default:
|
||||||
return dns.RcodeNameError
|
// Non-address types reach LookupIP only unexpectedly; without an
|
||||||
|
// address pair to probe we cannot prove the name is absent, so answer
|
||||||
|
// NODATA rather than a poisoning NXDOMAIN.
|
||||||
|
return dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||||
@@ -184,6 +188,230 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
|
|||||||
return dns.RcodeSuccess
|
return dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordResolver is the host resolver surface used to forward non-address
|
||||||
|
// record queries. net.DefaultResolver satisfies it.
|
||||||
|
type RecordResolver interface {
|
||||||
|
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
|
||||||
|
LookupTXT(ctx context.Context, name string) ([]string, error)
|
||||||
|
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
|
||||||
|
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
|
||||||
|
LookupCNAME(ctx context.Context, host string) (string, error)
|
||||||
|
LookupAddr(ctx context.Context, addr string) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupRecords resolves a non-address DNS record type through the host
|
||||||
|
// resolver and returns the resource records and the DNS rcode. Types the host
|
||||||
|
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
|
||||||
|
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
|
||||||
|
// for an unsupported type.
|
||||||
|
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
|
||||||
|
fqdn := dns.Fqdn(name)
|
||||||
|
|
||||||
|
switch qtype {
|
||||||
|
case dns.TypeMX:
|
||||||
|
return lookupMX(ctx, r, name, fqdn, ttl)
|
||||||
|
case dns.TypeTXT:
|
||||||
|
return lookupTXT(ctx, r, name, fqdn, ttl)
|
||||||
|
case dns.TypeNS:
|
||||||
|
return lookupNS(ctx, r, name, fqdn, ttl)
|
||||||
|
case dns.TypeSRV:
|
||||||
|
return lookupSRV(ctx, r, name, fqdn, ttl)
|
||||||
|
case dns.TypeCNAME:
|
||||||
|
return lookupCNAME(ctx, r, name, fqdn, ttl)
|
||||||
|
case dns.TypePTR:
|
||||||
|
return lookupPTR(ctx, r, name, fqdn, ttl)
|
||||||
|
default:
|
||||||
|
return nil, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
|
||||||
|
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
recs, err := r.LookupMX(ctx, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
rrs := make([]dns.RR, 0, len(recs))
|
||||||
|
for _, mx := range recs {
|
||||||
|
rrs = append(rrs, &dns.MX{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
|
||||||
|
Preference: mx.Pref,
|
||||||
|
Mx: dns.Fqdn(mx.Host),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return rrs, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
recs, err := r.LookupTXT(ctx, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
rrs := make([]dns.RR, 0, len(recs))
|
||||||
|
for _, txt := range recs {
|
||||||
|
rrs = append(rrs, &dns.TXT{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
|
||||||
|
Txt: chunkTXT(txt),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return rrs, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
recs, err := r.LookupNS(ctx, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
rrs := make([]dns.RR, 0, len(recs))
|
||||||
|
for _, ns := range recs {
|
||||||
|
rrs = append(rrs, &dns.NS{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
|
||||||
|
Ns: dns.Fqdn(ns.Host),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return rrs, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
_, recs, err := r.LookupSRV(ctx, "", "", name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
rrs := make([]dns.RR, 0, len(recs))
|
||||||
|
for _, srv := range recs {
|
||||||
|
rrs = append(rrs, &dns.SRV{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
|
||||||
|
Priority: srv.Priority,
|
||||||
|
Weight: srv.Weight,
|
||||||
|
Port: srv.Port,
|
||||||
|
Target: dns.Fqdn(srv.Target),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return rrs, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
cname, err := r.LookupCNAME(ctx, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
// LookupCNAME returns the queried name itself when the name resolves but
|
||||||
|
// has no CNAME record; that is a NODATA result, not a CNAME.
|
||||||
|
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
|
||||||
|
return nil, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
return []dns.RR{&dns.CNAME{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
|
||||||
|
Target: dns.Fqdn(cname),
|
||||||
|
}}, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
|
||||||
|
addr, ok := ptrQueryAddr(name)
|
||||||
|
if !ok {
|
||||||
|
return nil, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
names, err := r.LookupAddr(ctx, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, rcodeForRecordError(err)
|
||||||
|
}
|
||||||
|
rrs := make([]dns.RR, 0, len(names))
|
||||||
|
for _, n := range names {
|
||||||
|
rrs = append(rrs, &dns.PTR{
|
||||||
|
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
|
||||||
|
Ptr: dns.Fqdn(n),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return rrs, dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
|
||||||
|
// into the address string expected by net.Resolver.LookupAddr. It reports false
|
||||||
|
// when the name is not a well-formed reverse name.
|
||||||
|
func ptrQueryAddr(qname string) (string, bool) {
|
||||||
|
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(name, ".in-addr.arpa"):
|
||||||
|
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
|
||||||
|
case strings.HasSuffix(name, ".ip6.arpa"):
|
||||||
|
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
|
||||||
|
// address string, reporting false when it is not a well-formed reverse name.
|
||||||
|
func parseInAddrArpa(labelPart string) (string, bool) {
|
||||||
|
labels := strings.Split(labelPart, ".")
|
||||||
|
if len(labels) != 4 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
slices.Reverse(labels)
|
||||||
|
addr, err := netip.ParseAddr(strings.Join(labels, "."))
|
||||||
|
if err != nil || !addr.Is4() {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return addr.String(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
|
||||||
|
// address string, reporting false when it is not a well-formed reverse name.
|
||||||
|
func parseIP6Arpa(nibblePart string) (string, bool) {
|
||||||
|
nibbles := strings.Split(nibblePart, ".")
|
||||||
|
if len(nibbles) != 32 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
slices.Reverse(nibbles)
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, n := range nibbles {
|
||||||
|
if i > 0 && i%4 == 0 {
|
||||||
|
sb.WriteByte(':')
|
||||||
|
}
|
||||||
|
sb.WriteString(n)
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddr(sb.String())
|
||||||
|
if err != nil || !addr.Is6() {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return addr.String(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
|
||||||
|
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
|
||||||
|
// does not distinguish a missing name from a name that exists only with records
|
||||||
|
// of other types, so the name cannot be proven absent and must not be poisoned.
|
||||||
|
func rcodeForRecordError(err error) int {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||||
|
return dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
return dns.RcodeServerFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
|
||||||
|
// so the record can be packed. The chunks form one TXT resource record.
|
||||||
|
func chunkTXT(s string) []string {
|
||||||
|
const maxLen = 255
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return []string{s}
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []string
|
||||||
|
for len(s) > maxLen {
|
||||||
|
chunks = append(chunks, s[:maxLen])
|
||||||
|
s = s[maxLen:]
|
||||||
|
}
|
||||||
|
if len(s) > 0 {
|
||||||
|
chunks = append(chunks, s)
|
||||||
|
}
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
// FormatAnswers formats DNS resource records for logging.
|
// FormatAnswers formats DNS resource records for logging.
|
||||||
func FormatAnswers(answers []dns.RR) string {
|
func FormatAnswers(answers []dns.RR) string {
|
||||||
if len(answers) == 0 {
|
if len(answers) == 0 {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -121,6 +122,164 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
|||||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPtrQueryAddr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
qname string
|
||||||
|
want string
|
||||||
|
wantOK bool
|
||||||
|
}{
|
||||||
|
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
|
||||||
|
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||||
|
want: "2001:db8::1",
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
|
||||||
|
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
|
||||||
|
{name: "not a reverse name", qname: "example.com.", wantOK: false},
|
||||||
|
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, ok := ptrQueryAddr(tt.qname)
|
||||||
|
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
|
||||||
|
if tt.wantOK {
|
||||||
|
assert.Equal(t, tt.want, got, "parsed address mismatch")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockRecordResolver struct {
|
||||||
|
mx []*net.MX
|
||||||
|
txt []string
|
||||||
|
ns []*net.NS
|
||||||
|
srv []*net.SRV
|
||||||
|
cname string
|
||||||
|
ptr []string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
|
||||||
|
return m.mx, m.err
|
||||||
|
}
|
||||||
|
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
|
||||||
|
return m.txt, m.err
|
||||||
|
}
|
||||||
|
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
|
||||||
|
return m.ns, m.err
|
||||||
|
}
|
||||||
|
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||||
|
return "", m.srv, m.err
|
||||||
|
}
|
||||||
|
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
|
||||||
|
return m.cname, m.err
|
||||||
|
}
|
||||||
|
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
|
||||||
|
return m.ptr, m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupRecords(t *testing.T) {
|
||||||
|
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
|
||||||
|
|
||||||
|
t.Run("MX success", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TXT short string is one character-string", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TXT chunks long strings", func(t *testing.T) {
|
||||||
|
long := strings.Repeat("a", 300)
|
||||||
|
r := &mockRecordResolver{txt: []string{long}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
txt := rrs[0].(*dns.TXT).Txt
|
||||||
|
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
|
||||||
|
assert.Equal(t, 255, len(txt[0]))
|
||||||
|
assert.Equal(t, 45, len(txt[1]))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NS success", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SRV success", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME success", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{cname: "target.example.com."}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{cname: "example.com."}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PTR success", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
require.Len(t, rrs, 1)
|
||||||
|
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
assert.Empty(t, rrs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{err: notFound}
|
||||||
|
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
|
||||||
|
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
|
||||||
|
assert.Equal(t, dns.RcodeServerFailure, rcode)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported type is NODATA", func(t *testing.T) {
|
||||||
|
r := &mockRecordResolver{}
|
||||||
|
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, rcode)
|
||||||
|
assert.Empty(t, rrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripOPT(t *testing.T) {
|
func TestStripOPT(t *testing.T) {
|
||||||
rm := &dns.Msg{
|
rm := &dns.Msg{
|
||||||
Extra: []dns.RR{
|
Extra: []dns.RR{
|
||||||
|
|||||||
485
client/internal/dns/server_privileged_test.go
Normal file
485
client/internal/dns/server_privileged_test.go
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initUpstreamMap []handlerWrapper
|
||||||
|
initLocalZones []nbdns.CustomZone
|
||||||
|
initSerial uint64
|
||||||
|
inputSerial uint64
|
||||||
|
inputUpdate nbdns.Config
|
||||||
|
shouldFail bool
|
||||||
|
expectedUpstreamMap []handlerWrapper
|
||||||
|
expectedLocalQs []dns.Question
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Initial Config Should Succeed",
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New Config Should Succeed",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "netbird.io",
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
priority: PriorityLocal,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 2,
|
||||||
|
inputSerial: 1,
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldFail: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
|
initLocalZones: []nbdns.CustomZone{},
|
||||||
|
initUpstreamMap: nil,
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedUpstreamMap: []handlerWrapper{{
|
||||||
|
domain: ".",
|
||||||
|
priority: PriorityDefault,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disabled Service Should clean map",
|
||||||
|
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||||
|
initUpstreamMap: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &mockHandler{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
|
expectedUpstreamMap: nil,
|
||||||
|
expectedLocalQs: []dns.Question{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||||
|
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wgIface.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = dnsServer.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
||||||
|
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||||
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
|
if err != nil {
|
||||||
|
if testCase.shouldFail {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
||||||
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range testCase.expectedUpstreamMap {
|
||||||
|
found := false
|
||||||
|
for _, got := range dnsServer.dnsMuxHandlers {
|
||||||
|
if got.domain == expected.domain && got.priority == expected.priority {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, q := range testCase.expectedLocalQs {
|
||||||
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||||
|
Question: []dns.Question{q},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testCase.expectedLocalQs) > 0 {
|
||||||
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||||
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create stdnet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
opts := iface.WGIFaceOpts{
|
||||||
|
IFaceName: "utun2301",
|
||||||
|
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
||||||
|
WGPort: 33100,
|
||||||
|
WGPrivKey: privKey.String(),
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
TransportNet: newNet,
|
||||||
|
}
|
||||||
|
wgIface, err := iface.NewWGIFace(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create and init wireguard interface: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = wgIface.Close(); err != nil {
|
||||||
|
t.Logf("close wireguard interface: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("run DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||||
|
t.Logf("restore DNS settings on the host: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &local.Resolver{},
|
||||||
|
priority: PriorityUpstream,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||||
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server with regular configuration
|
||||||
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := update
|
||||||
|
update2.ServiceEnable = false
|
||||||
|
// Disable the server, stop the listener
|
||||||
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update3 := update2
|
||||||
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||||
|
// But service still get updates and we checking that we handle
|
||||||
|
// internal state in the right way
|
||||||
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -23,7 +22,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -104,466 +102,6 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
initUpstreamMap []handlerWrapper
|
|
||||||
initLocalZones []nbdns.CustomZone
|
|
||||||
initSerial uint64
|
|
||||||
inputSerial uint64
|
|
||||||
inputUpdate nbdns.Config
|
|
||||||
shouldFail bool
|
|
||||||
expectedUpstreamMap []handlerWrapper
|
|
||||||
expectedLocalQs []dns.Question
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Initial Config Should Succeed",
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.io",
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
priority: PriorityDefault,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "New Config Should Succeed",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: "netbird.io",
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
domain: "netbird.cloud",
|
|
||||||
priority: PriorityLocal,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 2,
|
|
||||||
inputSerial: 1,
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
shouldFail: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
|
||||||
initLocalZones: []nbdns.CustomZone{},
|
|
||||||
initUpstreamMap: nil,
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUpstreamMap: []handlerWrapper{{
|
|
||||||
domain: ".",
|
|
||||||
priority: PriorityDefault,
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
|
||||||
expectedUpstreamMap: nil,
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Disabled Service Should clean map",
|
|
||||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
|
||||||
initUpstreamMap: []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &mockHandler{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
initSerial: 0,
|
|
||||||
inputSerial: 1,
|
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
|
||||||
expectedUpstreamMap: nil,
|
|
||||||
expectedLocalQs: []dns.Question{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
|
||||||
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = wgIface.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = dnsServer.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
|
|
||||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
|
||||||
if err != nil {
|
|
||||||
if testCase.shouldFail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
|
|
||||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expected := range testCase.expectedUpstreamMap {
|
|
||||||
found := false
|
|
||||||
for _, got := range dnsServer.dnsMuxHandlers {
|
|
||||||
if got.domain == expected.domain && got.priority == expected.priority {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
|
||||||
responseWriter := &test.MockResponseWriter{
|
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
responseMSG = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, q := range testCase.expectedLocalQs {
|
|
||||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
|
||||||
Question: []dns.Question{q},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(testCase.expectedLocalQs) > 0 {
|
|
||||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
|
||||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
|
||||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
||||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create stdnet: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: "utun2301",
|
|
||||||
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: privKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
wgIface, err := iface.NewWGIFace(opts)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("build interface wireguard: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = wgIface.Create()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create and init wireguard interface: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = wgIface.Close(); err != nil {
|
|
||||||
t.Logf("close wireguard interface: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
||||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
||||||
t.Errorf("set packet filter: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
|
||||||
WgInterface: wgIface,
|
|
||||||
CustomAddress: "",
|
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
|
||||||
StateManager: nil,
|
|
||||||
DisableSys: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("create DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dnsServer.Initialize()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("run DNS server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
||||||
t.Logf("restore DNS settings on the host: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
dnsServer.dnsMuxHandlers = []handlerWrapper{
|
|
||||||
{
|
|
||||||
domain: zoneRecords[0].Name,
|
|
||||||
handler: &local.Resolver{},
|
|
||||||
priority: PriorityUpstream,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
|
||||||
dnsServer.updateSerial = 0
|
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: netip.MustParseAddr("8.8.4.4"),
|
|
||||||
NSType: nbdns.UDPNameServerType,
|
|
||||||
Port: 53,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
update := nbdns.Config{
|
|
||||||
ServiceEnable: true,
|
|
||||||
CustomZones: []nbdns.CustomZone{
|
|
||||||
{
|
|
||||||
Domain: "netbird.cloud",
|
|
||||||
Records: zoneRecords,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NameServerGroups: []*nbdns.NameServerGroup{
|
|
||||||
{
|
|
||||||
Domains: []string{"netbird.io"},
|
|
||||||
NameServers: nameServers,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
NameServers: nameServers,
|
|
||||||
Primary: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server with regular configuration
|
|
||||||
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update2 := update
|
|
||||||
update2.ServiceEnable = false
|
|
||||||
// Disable the server, stop the listener
|
|
||||||
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
update3 := update2
|
|
||||||
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
||||||
// But service still get updates and we checking that we handle
|
|
||||||
// internal state in the right way
|
|
||||||
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
||||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ const (
|
|||||||
|
|
||||||
type resolver interface {
|
type resolver interface {
|
||||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
|
||||||
|
LookupTXT(ctx context.Context, name string) ([]string, error)
|
||||||
|
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
|
||||||
|
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
|
||||||
|
LookupCNAME(ctx context.Context, host string) (string, error)
|
||||||
|
LookupAddr(ctx context.Context, addr string) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type firewaller interface {
|
type firewaller interface {
|
||||||
@@ -210,12 +216,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
|||||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
|
||||||
if network == "" {
|
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
|
||||||
f.writeResponse(logger, w, resp, qname, startTime)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
@@ -227,9 +227,46 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
reqHasEdns := query.IsEdns0() != nil
|
||||||
|
|
||||||
|
switch question.Qtype {
|
||||||
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
|
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
|
||||||
|
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
|
||||||
|
f.handleRecordQuery(ctx, logger, w, resp, startTime)
|
||||||
|
default:
|
||||||
|
// The domain is routed here, so any other type is answered NODATA
|
||||||
|
// (NOERROR, empty answer) rather than falling back to a resolver that
|
||||||
|
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
|
||||||
|
// client tell this capability-driven NODATA apart from an
|
||||||
|
// authoritative one. The OPT pseudo-record must not appear unless the
|
||||||
|
// query advertised EDNS0.
|
||||||
|
if reqHasEdns {
|
||||||
|
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
|
||||||
|
}
|
||||||
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
|
||||||
|
// resolved-IP state, and caches the answer for resilience on upstream failure.
|
||||||
|
func (f *DNSForwarder) handleAddressQuery(
|
||||||
|
ctx context.Context,
|
||||||
|
logger *log.Entry,
|
||||||
|
w dns.ResponseWriter,
|
||||||
|
resp *dns.Msg,
|
||||||
|
mostSpecificResId route.ResID,
|
||||||
|
matchingEntries []*ForwarderEntry,
|
||||||
|
reqHasEdns bool,
|
||||||
|
startTime time.Time,
|
||||||
|
) {
|
||||||
|
question := resp.Question[0]
|
||||||
|
qname := strings.ToLower(question.Name)
|
||||||
|
|
||||||
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +277,25 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
|||||||
f.writeResponse(logger, w, resp, qname, startTime)
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
|
||||||
|
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
|
||||||
|
// the routed name is never poisoned with NXDOMAIN.
|
||||||
|
func (f *DNSForwarder) handleRecordQuery(
|
||||||
|
ctx context.Context,
|
||||||
|
logger *log.Entry,
|
||||||
|
w dns.ResponseWriter,
|
||||||
|
resp *dns.Msg,
|
||||||
|
startTime time.Time,
|
||||||
|
) {
|
||||||
|
question := resp.Question[0]
|
||||||
|
qname := strings.ToLower(question.Name)
|
||||||
|
|
||||||
|
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
|
||||||
|
resp.Rcode = rcode
|
||||||
|
resp.Answer = append(resp.Answer, records...)
|
||||||
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
|||||||
@@ -133,6 +133,41 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
|||||||
return args.Get(0).([]netip.Addr), args.Error(1)
|
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
|
||||||
|
args := m.Called(ctx, name)
|
||||||
|
recs, _ := args.Get(0).([]*net.MX)
|
||||||
|
return recs, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
|
||||||
|
args := m.Called(ctx, name)
|
||||||
|
recs, _ := args.Get(0).([]string)
|
||||||
|
return recs, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
|
||||||
|
args := m.Called(ctx, name)
|
||||||
|
recs, _ := args.Get(0).([]*net.NS)
|
||||||
|
return recs, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
|
||||||
|
args := m.Called(ctx, service, proto, name)
|
||||||
|
recs, _ := args.Get(1).([]*net.SRV)
|
||||||
|
return args.String(0), recs, args.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
|
||||||
|
args := m.Called(ctx, host)
|
||||||
|
return args.String(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
|
||||||
|
args := m.Called(ctx, addr)
|
||||||
|
recs, _ := args.Get(0).([]string)
|
||||||
|
return recs, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -545,12 +580,15 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||||
|
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
|
||||||
|
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
queryType uint16
|
queryType uint16
|
||||||
queryDomain string
|
queryDomain string
|
||||||
configured string
|
configured string
|
||||||
expectedCode int
|
expectedCode int
|
||||||
|
expectEDE bool
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -562,28 +600,13 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
description: "RFC compliant REFUSED for unauthorized queries",
|
description: "RFC compliant REFUSED for unauthorized queries",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unsupported query type returns NOTIMP",
|
name: "unsupported query type returns NODATA",
|
||||||
queryType: dns.TypeMX,
|
queryType: dns.TypeCAA,
|
||||||
queryDomain: "example.com",
|
queryDomain: "example.com",
|
||||||
configured: "example.com",
|
configured: "example.com",
|
||||||
expectedCode: dns.RcodeNotImplemented,
|
expectedCode: dns.RcodeSuccess,
|
||||||
description: "RFC compliant NOTIMP for unsupported types",
|
expectEDE: true,
|
||||||
},
|
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
|
||||||
{
|
|
||||||
name: "CNAME query returns NOTIMP",
|
|
||||||
queryType: dns.TypeCNAME,
|
|
||||||
queryDomain: "example.com",
|
|
||||||
configured: "example.com",
|
|
||||||
expectedCode: dns.RcodeNotImplemented,
|
|
||||||
description: "CNAME queries not supported",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "TXT query returns NOTIMP",
|
|
||||||
queryType: dns.TypeTXT,
|
|
||||||
queryDomain: "example.com",
|
|
||||||
configured: "example.com",
|
|
||||||
expectedCode: dns.RcodeNotImplemented,
|
|
||||||
description: "TXT queries not supported",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -599,6 +622,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
|
|
||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||||
|
query.SetEdns0(dns.DefaultMsgSize, false)
|
||||||
|
|
||||||
// Capture the written response
|
// Capture the written response
|
||||||
var writtenResp *dns.Msg
|
var writtenResp *dns.Msg
|
||||||
@@ -614,10 +638,213 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
|
||||||
|
|
||||||
|
if tt.expectEDE {
|
||||||
|
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
|
||||||
|
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
|
||||||
|
"unsupported type NODATA should carry EDE Not Supported")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasEDE(m *dns.Msg, code uint16) bool {
|
||||||
|
opt := m.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, o := range opt.Option {
|
||||||
|
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_RecordQueries(t *testing.T) {
|
||||||
|
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
|
||||||
|
|
||||||
|
t.Run("MX records are forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupMX", mock.Anything, "example.com.").
|
||||||
|
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
mx, ok := resp.Answer[0].(*dns.MX)
|
||||||
|
require.True(t, ok, "answer should be an MX record")
|
||||||
|
assert.Equal(t, uint16(10), mx.Preference)
|
||||||
|
assert.Equal(t, "mail.example.com.", mx.Mx)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
// A not-found cannot prove the name is absent (it may exist with only
|
||||||
|
// other record types), so it must answer NODATA, never NXDOMAIN.
|
||||||
|
mockResolver.On("LookupMX", mock.Anything, "example.com.").
|
||||||
|
Return(nil, notFound).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
|
||||||
|
assert.Empty(t, resp.Answer)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NS records are forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupNS", mock.Anything, "example.com.").
|
||||||
|
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
ns, ok := resp.Answer[0].(*dns.NS)
|
||||||
|
require.True(t, ok, "answer should be an NS record")
|
||||||
|
assert.Equal(t, "ns1.example.com.", ns.Ns)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing NS is NODATA", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupNS", mock.Anything, "example.com.").
|
||||||
|
Return(nil, notFound).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
assert.Empty(t, resp.Answer)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SRV records are forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
|
||||||
|
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
srv, ok := resp.Answer[0].(*dns.SRV)
|
||||||
|
require.True(t, ok, "answer should be an SRV record")
|
||||||
|
assert.Equal(t, "sip.example.com.", srv.Target)
|
||||||
|
assert.Equal(t, uint16(5060), srv.Port)
|
||||||
|
assert.Equal(t, uint16(10), srv.Priority)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing SRV is NODATA", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
|
||||||
|
Return("", nil, notFound).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
assert.Empty(t, resp.Answer)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TXT records are forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
|
||||||
|
Return([]string{"v=spf1 -all"}, nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
txt, ok := resp.Answer[0].(*dns.TXT)
|
||||||
|
require.True(t, ok, "answer should be a TXT record")
|
||||||
|
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME record is forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
|
||||||
|
|
||||||
|
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
|
||||||
|
Return("target.example.com.", nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
require.True(t, ok, "answer should be a CNAME record")
|
||||||
|
assert.Equal(t, "target.example.com.", cname.Target)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
|
||||||
|
|
||||||
|
// No CNAME exists: LookupCNAME echoes the queried name back.
|
||||||
|
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
|
||||||
|
Return("example.com.", nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PTR record is forwarded", func(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
|
||||||
|
|
||||||
|
// The reverse name is parsed back to the address LookupAddr expects.
|
||||||
|
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
|
||||||
|
Return([]string{"host.example.com."}, nil).Once()
|
||||||
|
|
||||||
|
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||||
|
require.True(t, ok, "answer should be a PTR record")
|
||||||
|
assert.Equal(t, "host.example.com.", ptr.Ptr)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
|
||||||
|
t.Helper()
|
||||||
|
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||||
|
forwarder.resolver = r
|
||||||
|
|
||||||
|
d, err := domain.FromString(configured)
|
||||||
|
require.NoError(t, err)
|
||||||
|
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
|
||||||
|
return forwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
|
||||||
|
t.Helper()
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(qname), qtype)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "expected response to be written")
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -82,6 +82,12 @@ const (
|
|||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
disableAutoUpdate = "disabled"
|
disableAutoUpdate = "disabled"
|
||||||
|
|
||||||
|
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
|
||||||
|
// check gathering. The gathering runs uncancellable system calls (process scan,
|
||||||
|
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
|
||||||
|
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
|
||||||
|
systemInfoTimeout = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
@@ -166,6 +172,7 @@ type EngineServices struct {
|
|||||||
StateManager *statemanager.Manager
|
StateManager *statemanager.Manager
|
||||||
UpdateManager *updater.Manager
|
UpdateManager *updater.Manager
|
||||||
ClientMetrics *metrics.ClientMetrics
|
ClientMetrics *metrics.ClientMetrics
|
||||||
|
MetricsCtx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -258,6 +265,7 @@ type Engine struct {
|
|||||||
|
|
||||||
// clientMetrics collects and pushes metrics
|
// clientMetrics collects and pushes metrics
|
||||||
clientMetrics *metrics.ClientMetrics
|
clientMetrics *metrics.ClientMetrics
|
||||||
|
metricsCtx context.Context
|
||||||
|
|
||||||
jobExecutor *jobexec.Executor
|
jobExecutor *jobexec.Executor
|
||||||
jobExecutorWG sync.WaitGroup
|
jobExecutorWG sync.WaitGroup
|
||||||
@@ -310,6 +318,7 @@ func NewEngine(
|
|||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
jobExecutor: jobexec.NewExecutor(),
|
jobExecutor: jobexec.NewExecutor(),
|
||||||
clientMetrics: services.ClientMetrics,
|
clientMetrics: services.ClientMetrics,
|
||||||
|
metricsCtx: services.MetricsCtx,
|
||||||
updateManager: services.UpdateManager,
|
updateManager: services.UpdateManager,
|
||||||
syncStoreDir: config.StateDir,
|
syncStoreDir: config.StateDir,
|
||||||
}
|
}
|
||||||
@@ -895,6 +904,16 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// phase times a sync sub-phase: it returns a function that records the elapsed
|
||||||
|
// duration when called. Starting the timer at the call site keeps inter-phase
|
||||||
|
// glue code out of the measurement.
|
||||||
|
func (e *Engine) phase(name string) func() {
|
||||||
|
start := time.Now()
|
||||||
|
return func() {
|
||||||
|
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -914,7 +933,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
done := e.phase("netbird_config")
|
||||||
|
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
||||||
|
done()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -928,11 +950,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
done = e.phase("checks")
|
||||||
|
err = e.updateChecksIfNew(update.Checks)
|
||||||
|
done()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done = e.phase("persist")
|
||||||
e.persistSyncResponse(update)
|
e.persistSyncResponse(update)
|
||||||
|
done()
|
||||||
|
|
||||||
// only apply new changes and ignore old ones
|
// only apply new changes and ignore old ones
|
||||||
if err := e.updateNetworkMap(nm); err != nil {
|
if err := e.updateNetworkMap(nm); err != nil {
|
||||||
@@ -973,6 +1000,8 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
|||||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.handleMetricsUpdate(wCfg.GetMetrics())
|
||||||
|
|
||||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||||
log.Warnf("Failed to update DNS server config: %v", err)
|
log.Warnf("Failed to update DNS server config: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1042,6 +1071,14 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
|||||||
return e.flowManager.Update(flowConfig)
|
return e.flowManager.Update(flowConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleMetricsUpdate(config *mgmProto.MetricsConfig) {
|
||||||
|
if config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("received metrics configuration from management: enabled=%v", config.GetEnabled())
|
||||||
|
e.clientMetrics.UpdatePushFromMgm(e.metricsCtx, config.GetEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||||
if config.GetInterval() == nil {
|
if config.GetInterval() == nil {
|
||||||
return nil, errors.New("flow interval is nil")
|
return nil, errors.New("flow interval is nil")
|
||||||
@@ -1066,11 +1103,22 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
}
|
}
|
||||||
e.checks = checks
|
e.checks = checks
|
||||||
|
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if !ok {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
// Gathering timed out; skip the meta sync this cycle rather than blocking the
|
||||||
info = system.GetInfo(e.ctx)
|
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
e.applyInfoFlags(info)
|
||||||
|
|
||||||
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
|
return fmt.Errorf("could not sync meta: error %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
|
||||||
|
func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||||
info.SetFlags(
|
info.SetFlags(
|
||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
@@ -1089,12 +1137,20 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||||
log.Errorf("could not sync meta: error %s", err)
|
// can be excluded from the reported network addresses; the interface coming and
|
||||||
return err
|
// going otherwise churns the peer meta on the management server.
|
||||||
|
func (e *Engine) overlayAddresses() []netip.Addr {
|
||||||
|
var ips []netip.Addr
|
||||||
|
if e.config.WgAddr.IP.IsValid() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IP)
|
||||||
}
|
}
|
||||||
return nil
|
if e.config.WgAddr.HasIPv6() {
|
||||||
|
ips = append(ips, e.config.WgAddr.IPv6)
|
||||||
|
}
|
||||||
|
return ips
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
@@ -1240,31 +1296,15 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
|
||||||
if err != nil {
|
if !ok {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
// Gathering timed out; connect the stream with base info so management
|
||||||
|
// connectivity still comes up rather than blocking here.
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx)
|
||||||
}
|
}
|
||||||
info.SetFlags(
|
e.applyInfoFlags(info)
|
||||||
e.config.RosenpassEnabled,
|
|
||||||
e.config.RosenpassPermissive,
|
|
||||||
&e.config.ServerSSHAllowed,
|
|
||||||
e.config.DisableClientRoutes,
|
|
||||||
e.config.DisableServerRoutes,
|
|
||||||
e.config.DisableDNS,
|
|
||||||
e.config.DisableFirewall,
|
|
||||||
e.config.BlockLANAccess,
|
|
||||||
e.config.BlockInbound,
|
|
||||||
e.config.DisableIPv6,
|
|
||||||
e.config.LazyConnectionEnabled,
|
|
||||||
e.config.EnableSSHRoot,
|
|
||||||
e.config.EnableSSHSFTP,
|
|
||||||
e.config.EnableSSHLocalPortForwarding,
|
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
|
||||||
e.config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1357,13 +1397,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
|
||||||
|
|
||||||
|
done := e.phase("dns_server")
|
||||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
done()
|
||||||
|
|
||||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||||
|
|
||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
|
done = e.phase("routes_classify")
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||||
|
|
||||||
@@ -1372,29 +1415,60 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
||||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||||
}
|
}
|
||||||
|
done()
|
||||||
|
|
||||||
|
done = e.phase("routes_apply")
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update routes: %v", err)
|
log.Errorf("failed to update routes: %v", err)
|
||||||
}
|
}
|
||||||
|
done()
|
||||||
|
|
||||||
|
done = e.phase("filtering")
|
||||||
if e.acl != nil {
|
if e.acl != nil {
|
||||||
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
||||||
}
|
}
|
||||||
|
done()
|
||||||
|
|
||||||
|
done = e.phase("dns_forwarder")
|
||||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||||
|
done()
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
|
done = e.phase("forward_rules")
|
||||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
}
|
}
|
||||||
|
done()
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
|
done = e.phase("offline_peers")
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
|
done()
|
||||||
|
|
||||||
|
remotePeers, err := e.reconcilePeers(networkMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
|
done = e.phase("lazy_exclude")
|
||||||
|
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||||
|
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||||
|
done()
|
||||||
|
|
||||||
|
e.networkSerial = serial
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconcilePeers applies the remote peer list from the network map (removing,
|
||||||
|
// modifying and adding peers, then updating SSH config) and returns the remote
|
||||||
|
// peers with our own peer filtered out, for use by later sync steps.
|
||||||
|
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
|
||||||
// Filter out own peer from the remote peers list
|
// Filter out own peer from the remote peers list
|
||||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||||
@@ -1409,42 +1483,43 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
return remotePeers, nil
|
||||||
err := e.removePeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.modifyPeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.addNewPeers(remotePeers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
|
|
||||||
e.updatePeerSSHHostKeys(remotePeers)
|
|
||||||
|
|
||||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
|
||||||
log.Warnf("failed to update SSH client config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
done := e.phase("removed_peers")
|
||||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
err := e.removePeers(remotePeers)
|
||||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
done()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
e.networkSerial = serial
|
done = e.phase("modified_peers")
|
||||||
|
err = e.modifyPeers(remotePeers)
|
||||||
|
done()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
done = e.phase("added_peers")
|
||||||
|
err = e.addNewPeers(remotePeers)
|
||||||
|
done()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
|
e.updatePeerSSHHostKeys(remotePeers)
|
||||||
|
|
||||||
|
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
|
||||||
|
return remotePeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||||
|
|||||||
565
client/internal/engine_privileged_test.go
Normal file
565
client/internal/engine_privileged_test.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEngine_SSH(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(
|
||||||
|
ctx, cancel,
|
||||||
|
&EngineConfig{
|
||||||
|
WgIfaceName: "utun101",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
ServerSSHAllowed: true,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
SSHKey: sshKey,
|
||||||
|
},
|
||||||
|
EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
},
|
||||||
|
MobileDependency{},
|
||||||
|
)
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.21/24"},
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
||||||
|
networkMap := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 6,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// SSH server is enabled, therefore SSH config should be applied
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 7,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now remove peer
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 8,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// time.Sleep(250 * time.Millisecond)
|
||||||
|
assert.NotNil(t, engine.sshServer)
|
||||||
|
|
||||||
|
// now disable SSH server
|
||||||
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
|
Serial: 9,
|
||||||
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
|
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.updateNetworkMap(networkMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Nil(t, engine.sshServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Sync(t *testing.T) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// feed updates to Engine via mocked Management client
|
||||||
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
|
defer close(updates)
|
||||||
|
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
|
for msg := range updates {
|
||||||
|
err := msgHandler(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
|
WgIfaceName: "utun103",
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
|
engine.ctx = ctx
|
||||||
|
|
||||||
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.10/24"},
|
||||||
|
}
|
||||||
|
peer2 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.11/24"},
|
||||||
|
}
|
||||||
|
peer3 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.12/24"},
|
||||||
|
}
|
||||||
|
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||||
|
updates <- &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.After(time.Second * 2)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("timeout while waiting for test to finish")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_MultiplePeers(t *testing.T) {
|
||||||
|
// log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigServer, signalAddr, err := startSignal(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sigServer.Stop()
|
||||||
|
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
|
mu := sync.Mutex{}
|
||||||
|
engines := []*Engine{}
|
||||||
|
numPeers := 10
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(numPeers)
|
||||||
|
// create and start peers
|
||||||
|
for i := 0; i < numPeers; i++ {
|
||||||
|
j := i
|
||||||
|
go func() {
|
||||||
|
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
||||||
|
if err != nil {
|
||||||
|
wg.Done()
|
||||||
|
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engine.dnsServer = &dns.MockServer{}
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
err = engine.Start(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engines = append(engines, engine)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait until all have been created and started
|
||||||
|
wg.Wait()
|
||||||
|
if len(engines) != numPeers {
|
||||||
|
t.Fatal("not all peers were started")
|
||||||
|
}
|
||||||
|
// check whether all the peer have expected peers connected
|
||||||
|
|
||||||
|
expectedConnected := numPeers * (numPeers - 1)
|
||||||
|
|
||||||
|
// adjust according to timeouts
|
||||||
|
timeout := 50 * time.Second
|
||||||
|
timeoutChan := time.After(timeout)
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
||||||
|
break loop
|
||||||
|
case <-ticker.C:
|
||||||
|
totalConnected := 0
|
||||||
|
for _, engine := range engines {
|
||||||
|
totalConnected += getConnectedPeers(engine)
|
||||||
|
}
|
||||||
|
if totalConnected == expectedConnected {
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
log.Infof("total connected=%d", totalConnected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cleanup test
|
||||||
|
for n, peerEngine := range engines {
|
||||||
|
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
||||||
|
errStop := peerEngine.mgmClient.Close()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
||||||
|
}
|
||||||
|
errStop = peerEngine.Stop()
|
||||||
|
if errStop != nil {
|
||||||
|
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
kaep = keepalive.EnforcementPolicy{
|
||||||
|
MinTime: 15 * time.Second,
|
||||||
|
PermitWithoutStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kasp = keepalive.ServerParameters{
|
||||||
|
MaxConnectionIdle: 15 * time.Second,
|
||||||
|
MaxConnectionAgeGrace: 5 * time.Second,
|
||||||
|
Time: 5 * time.Second,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ifaceName string
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||||
|
} else {
|
||||||
|
ifaceName = fmt.Sprintf("wt%d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgPort := 33100 + i
|
||||||
|
conf := &EngineConfig{
|
||||||
|
WgIfaceName: ifaceName,
|
||||||
|
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: wgPort,
|
||||||
|
MTU: iface.DefaultMTU,
|
||||||
|
}
|
||||||
|
|
||||||
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmtClient,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{}), nil
|
||||||
|
e.ctx = ctx
|
||||||
|
return e, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
config := &config.Config{
|
||||||
|
Stuns: []*config.Host{},
|
||||||
|
TURNConfig: &config.TURNConfig{},
|
||||||
|
Relay: &config.Relay{
|
||||||
|
Addresses: []string{"127.0.0.1:1234"},
|
||||||
|
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||||
|
Secret: "222222222222222222",
|
||||||
|
},
|
||||||
|
Signal: &config.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: "localhost:10000",
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.ExtraSettings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
|
func getConnectedPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
i := 0
|
||||||
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
|
if conn.IsConnected() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPeers(e *Engine) int {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
return len(e.peerStore.PeersPubKey())
|
||||||
|
}
|
||||||
@@ -6,37 +6,18 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -50,18 +31,7 @@ import (
|
|||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||||
@@ -69,25 +39,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
"github.com/netbirdio/netbird/shared/netiputil"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
kaep = keepalive.EnforcementPolicy{
|
|
||||||
MinTime: 15 * time.Second,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
kasp = keepalive.ServerParameters{
|
|
||||||
MaxConnectionIdle: 15 * time.Second,
|
|
||||||
MaxConnectionAgeGrace: 5 * time.Second,
|
|
||||||
Time: 5 * time.Second,
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockWGIface struct {
|
type MockWGIface struct {
|
||||||
CreateFunc func() error
|
CreateFunc func() error
|
||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||||
@@ -224,6 +178,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) MTU() uint16 {
|
||||||
|
return 1280
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -234,129 +192,6 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_SSH(t *testing.T) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(
|
|
||||||
ctx, cancel,
|
|
||||||
&EngineConfig{
|
|
||||||
WgIfaceName: "utun101",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
ServerSSHAllowed: true,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
SSHKey: sshKey,
|
|
||||||
},
|
|
||||||
EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
},
|
|
||||||
MobileDependency{},
|
|
||||||
)
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
peerWithSSH := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.21/24"},
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSH server is not enabled so SSH config of a remote peer should be ignored
|
|
||||||
networkMap := &mgmtProto.NetworkMap{
|
|
||||||
Serial: 6,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// SSH server is enabled, therefore SSH config should be applied
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 7,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{
|
|
||||||
SshEnabled: true,
|
|
||||||
JwtConfig: &mgmtProto.JWTConfig{
|
|
||||||
Issuer: "test-issuer",
|
|
||||||
Audience: "test-audience",
|
|
||||||
KeysLocation: "test-keys",
|
|
||||||
MaxTokenAge: 3600,
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now remove peer
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 8,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
|
||||||
assert.NotNil(t, engine.sshServer)
|
|
||||||
|
|
||||||
// now disable SSH server
|
|
||||||
networkMap = &mgmtProto.NetworkMap{
|
|
||||||
Serial: 9,
|
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||||
// Test that SSH server start/stop logic works based on config
|
// Test that SSH server start/stop logic works based on config
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
@@ -631,97 +466,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_Sync(t *testing.T) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// feed updates to Engine via mocked Management client
|
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
|
||||||
defer close(updates)
|
|
||||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
|
||||||
for msg := range updates {
|
|
||||||
err := msgHandler(msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
|
||||||
WgIfaceName: "utun103",
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}, EngineServices{
|
|
||||||
SignalClient: &signal.MockClient{},
|
|
||||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{})
|
|
||||||
engine.ctx = ctx
|
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := engine.Stop()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
peer1 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.10/24"},
|
|
||||||
}
|
|
||||||
peer2 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.11/24"},
|
|
||||||
}
|
|
||||||
peer3 := &mgmtProto.RemotePeerConfig{
|
|
||||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
|
||||||
AllowedIps: []string{"100.64.0.12/24"},
|
|
||||||
}
|
|
||||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
|
||||||
updates <- &mgmtProto.SyncResponse{
|
|
||||||
NetworkMap: &mgmtProto.NetworkMap{
|
|
||||||
Serial: 10,
|
|
||||||
PeerConfig: nil,
|
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
|
||||||
RemotePeersIsEmpty: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := time.After(time.Second * 2)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
t.Fatalf("timeout while waiting for test to finish")
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if getPeers(engine) == 3 && engine.networkSerial == 10 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1105,104 +849,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_MultiplePeers(t *testing.T) {
|
|
||||||
// log.SetLevel(log.DebugLevel)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigServer, signalAddr, err := startSignal(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer sigServer.Stop()
|
|
||||||
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer mgmtServer.GracefulStop()
|
|
||||||
|
|
||||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
|
||||||
|
|
||||||
mu := sync.Mutex{}
|
|
||||||
engines := []*Engine{}
|
|
||||||
numPeers := 10
|
|
||||||
wg := sync.WaitGroup{}
|
|
||||||
wg.Add(numPeers)
|
|
||||||
// create and start peers
|
|
||||||
for i := 0; i < numPeers; i++ {
|
|
||||||
j := i
|
|
||||||
go func() {
|
|
||||||
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
|
|
||||||
if err != nil {
|
|
||||||
wg.Done()
|
|
||||||
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engine.dnsServer = &dns.MockServer{}
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
|
||||||
err = engine.Start(nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
|
||||||
wg.Done()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
engines = append(engines, engine)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait until all have been created and started
|
|
||||||
wg.Wait()
|
|
||||||
if len(engines) != numPeers {
|
|
||||||
t.Fatal("not all peers was started")
|
|
||||||
}
|
|
||||||
// check whether all the peer have expected peers connected
|
|
||||||
|
|
||||||
expectedConnected := numPeers * (numPeers - 1)
|
|
||||||
|
|
||||||
// adjust according to timeouts
|
|
||||||
timeout := 50 * time.Second
|
|
||||||
timeoutChan := time.After(timeout)
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-timeoutChan:
|
|
||||||
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
|
|
||||||
break loop
|
|
||||||
case <-ticker.C:
|
|
||||||
totalConnected := 0
|
|
||||||
for _, engine := range engines {
|
|
||||||
totalConnected += getConnectedPeers(engine)
|
|
||||||
}
|
|
||||||
if totalConnected == expectedConnected {
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
log.Infof("total connected=%d", totalConnected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// cleanup test
|
|
||||||
for n, peerEngine := range engines {
|
|
||||||
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
|
|
||||||
errStop := peerEngine.mgmClient.Close()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close management clients from engine: ", errStop)
|
|
||||||
}
|
|
||||||
errStop = peerEngine.Stop()
|
|
||||||
if errStop != nil {
|
|
||||||
log.Infoln("got error trying to close testing peers engine: ", errStop)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||||
ifaceList, err := net.Interfaces()
|
ifaceList, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1526,187 +1172,6 @@ func TestCompareNetIPLists(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ifaceName string
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
|
||||||
} else {
|
|
||||||
ifaceName = fmt.Sprintf("wt%d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgPort := 33100 + i
|
|
||||||
conf := &EngineConfig{
|
|
||||||
WgIfaceName: ifaceName,
|
|
||||||
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: wgPort,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
|
||||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
|
||||||
SignalClient: signalClient,
|
|
||||||
MgmClient: mgmtClient,
|
|
||||||
RelayManager: relayMgr,
|
|
||||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
|
||||||
}, MobileDependency{}), nil
|
|
||||||
e.ctx = ctx
|
|
||||||
return e, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
|
||||||
require.NoError(t, err)
|
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
config := &config.Config{
|
|
||||||
Stuns: []*config.Host{},
|
|
||||||
TURNConfig: &config.TURNConfig{},
|
|
||||||
Relay: &config.Relay{
|
|
||||||
Addresses: []string{"127.0.0.1:1234"},
|
|
||||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
|
||||||
Secret: "222222222222222222",
|
|
||||||
},
|
|
||||||
Signal: &config.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: "localhost:10000",
|
|
||||||
},
|
|
||||||
Datadir: dataDir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.Settings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
settingsMockManager.EXPECT().
|
|
||||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
|
||||||
Return(&types.ExtraSettings{}, nil).
|
|
||||||
AnyTimes()
|
|
||||||
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
|
||||||
func getConnectedPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
i := 0
|
|
||||||
for _, id := range e.peerStore.PeersPubKey() {
|
|
||||||
conn, _ := e.peerStore.PeerConn(id)
|
|
||||||
if conn.IsConnected() {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPeers(e *Engine) int {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
return len(e.peerStore.PeersPubKey())
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
b, err := netiputil.EncodePrefix(p)
|
b, err := netiputil.EncodePrefix(p)
|
||||||
|
|||||||
@@ -44,4 +44,5 @@ type wgIfaceBase interface {
|
|||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||||
|
MTU() uint16
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,15 +119,16 @@ func (d *BindListener) ReadPackets() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
|
||||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = d.lazyConn.Close()
|
_ = d.lazyConn.Close()
|
||||||
d.bind.RemoveEndpoint(d.fakeIP)
|
d.bind.RemoveEndpoint(d.fakeIP)
|
||||||
d.done.Done()
|
d.done.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
|
||||||
|
func (d *BindListener) CapturedPacket() []byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close stops the listener and cleans up resources.
|
// Close stops the listener and cleans up resources.
|
||||||
func (d *BindListener) Close() {
|
func (d *BindListener) Close() {
|
||||||
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||||
|
|||||||
@@ -45,10 +45,6 @@ type MockWGIfaceBind struct {
|
|||||||
endpointMgr *mockEndpointManager
|
endpointMgr *mockEndpointManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -68,6 +64,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
|||||||
return m.endpointMgr
|
return m.endpointMgr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) MTU() uint16 {
|
||||||
|
return 1280
|
||||||
|
}
|
||||||
|
|
||||||
func TestBindListener_Creation(t *testing.T) {
|
func TestBindListener_Creation(t *testing.T) {
|
||||||
mockEndpointMgr := newMockEndpointManager()
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
@@ -207,8 +207,9 @@ func TestManager_BindMode(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case peerConnID := <-mgr.OnActivityChan:
|
case ev := <-mgr.OnActivityChan:
|
||||||
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
|
||||||
|
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatal("timeout waiting for activity notification")
|
t.Fatal("timeout waiting for activity notification")
|
||||||
}
|
}
|
||||||
@@ -266,8 +267,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
|||||||
receivedPeers := make(map[peerid.ConnID]bool)
|
receivedPeers := make(map[peerid.ConnID]bool)
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
select {
|
select {
|
||||||
case peerConnID := <-mgr.OnActivityChan:
|
case ev := <-mgr.OnActivityChan:
|
||||||
receivedPeers[peerConnID] = true
|
receivedPeers[ev.PeerConnID] = true
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatal("timeout waiting for activity notifications")
|
t.Fatal("timeout waiting for activity notifications")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ package activity
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,6 +22,8 @@ type UDPListener struct {
|
|||||||
done sync.Mutex
|
done sync.Mutex
|
||||||
|
|
||||||
isClosed atomic.Bool
|
isClosed atomic.Bool
|
||||||
|
|
||||||
|
capturedPacket []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||||
@@ -46,9 +50,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||||
|
// The first packet that triggers activity is captured so it can be reinjected through the real
|
||||||
|
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
|
||||||
|
// dropped and WG would only retry after REKEY_TIMEOUT.
|
||||||
func (d *UDPListener) ReadPackets() {
|
func (d *UDPListener) ReadPackets() {
|
||||||
for {
|
for {
|
||||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
|
||||||
|
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if d.isClosed.Load() {
|
if d.isClosed.Load() {
|
||||||
d.peerCfg.Log.Infof("exit from activity listener")
|
d.peerCfg.Log.Infof("exit from activity listener")
|
||||||
@@ -62,20 +70,24 @@ func (d *UDPListener) ReadPackets() {
|
|||||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
d.peerCfg.Log.Infof("activity detected")
|
d.capturedPacket = slices.Clone(buf[:n])
|
||||||
|
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
|
||||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
|
||||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
// triggered activation.
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore close error as it may return "use of closed network connection" if already closed.
|
|
||||||
_ = d.conn.Close()
|
_ = d.conn.Close()
|
||||||
d.done.Unlock()
|
d.done.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
|
||||||
|
// Safe to call after ReadPackets returns.
|
||||||
|
func (d *UDPListener) CapturedPacket() []byte {
|
||||||
|
return d.capturedPacket
|
||||||
|
}
|
||||||
|
|
||||||
// Close stops the listener and cleans up resources.
|
// Close stops the listener and cleans up resources.
|
||||||
func (d *UDPListener) Close() {
|
func (d *UDPListener) Close() {
|
||||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||||
|
|||||||
@@ -19,17 +19,25 @@ import (
|
|||||||
type listener interface {
|
type listener interface {
|
||||||
ReadPackets()
|
ReadPackets()
|
||||||
Close()
|
Close()
|
||||||
|
CapturedPacket() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
|
||||||
|
// captured for reinjection through the real transport.
|
||||||
|
type Event struct {
|
||||||
|
PeerConnID peerid.ConnID
|
||||||
|
FirstPacket []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type WgInterface interface {
|
type WgInterface interface {
|
||||||
RemovePeer(peerKey string) error
|
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
MTU() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
OnActivityChan chan peerid.ConnID
|
OnActivityChan chan Event
|
||||||
|
|
||||||
wgIface WgInterface
|
wgIface WgInterface
|
||||||
|
|
||||||
@@ -41,7 +49,7 @@ type Manager struct {
|
|||||||
|
|
||||||
func NewManager(wgIface WgInterface) *Manager {
|
func NewManager(wgIface WgInterface) *Manager {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
OnActivityChan: make(chan Event, 1),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
peers: make(map[peerid.ConnID]listener),
|
peers: make(map[peerid.ConnID]listener),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
@@ -116,12 +124,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
|||||||
delete(m.peers, peerConnID)
|
delete(m.peers, peerConnID)
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
m.notify(peerConnID)
|
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
func (m *Manager) notify(ev Event) {
|
||||||
select {
|
select {
|
||||||
case <-m.done:
|
case <-m.done:
|
||||||
case m.OnActivityChan <- peerConnID:
|
case m.OnActivityChan <- ev:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package activity
|
package activity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -25,10 +26,6 @@ func (m *MocPeer) ConnID() peerid.ConnID {
|
|||||||
type MocWGIface struct {
|
type MocWGIface struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m MocWGIface) RemovePeer(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -44,6 +41,10 @@ func (m MocWGIface) Address() wgaddr.Address {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m MocWGIface) MTU() uint16 {
|
||||||
|
return 1280
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerListener is a test helper to access listeners
|
// GetPeerListener is a test helper to access listeners
|
||||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -86,11 +87,15 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case peerConnID := <-mgr.OnActivityChan:
|
case ev := <-mgr.OnActivityChan:
|
||||||
if peerConnID != peerCfg1.PeerConnID {
|
if ev.PeerConnID != peerCfg1.PeerConnID {
|
||||||
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
|
||||||
|
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
|
||||||
}
|
}
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for activity")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
case ev := <-m.activityManager.OnActivityChan:
|
||||||
m.onPeerActivity(peerConnID)
|
m.onPeerActivity(ev)
|
||||||
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||||
m.onPeerInactivityTimedOut(peerIDs)
|
m.onPeerInactivityTimedOut(peerIDs)
|
||||||
}
|
}
|
||||||
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
func (m *Manager) onPeerActivity(ev activity.Event) {
|
||||||
m.managedPeersMu.Lock()
|
m.managedPeersMu.Lock()
|
||||||
defer m.managedPeersMu.Unlock()
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf("peer not found by conn id: %v", peerConnID)
|
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
|||||||
|
|
||||||
m.activateHAGroupPeers(mp.peerCfg)
|
m.activateHAGroupPeers(mp.peerCfg)
|
||||||
|
|
||||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||||
|
|||||||
@@ -17,4 +17,5 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
|
MTU() uint16
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ func getMetricsInterval() time.Duration {
|
|||||||
return interval
|
return interval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMetricsPushEnvSet returns true if NB_METRICS_PUSH_ENABLED is explicitly set (to any value).
|
||||||
|
// When set, the env var takes full precedence over management server configuration.
|
||||||
|
func isMetricsPushEnvSet() bool {
|
||||||
|
_, set := os.LookupEnv(EnvMetricsPushEnabled)
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
func isForceSending() bool {
|
func isForceSending() bool {
|
||||||
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
|
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
|
||||||
return force
|
return force
|
||||||
|
|||||||
@@ -120,6 +120,30 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
|||||||
m.trimLocked()
|
m.trimLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
|
||||||
|
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
|
||||||
|
agentInfo.DeploymentType.String(),
|
||||||
|
agentInfo.Version,
|
||||||
|
agentInfo.OS,
|
||||||
|
agentInfo.Arch,
|
||||||
|
agentInfo.peerID,
|
||||||
|
phase,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.samples = append(m.samples, influxSample{
|
||||||
|
measurement: "netbird_sync_phase",
|
||||||
|
tags: tags,
|
||||||
|
fields: map[string]float64{
|
||||||
|
"duration_seconds": duration.Seconds(),
|
||||||
|
},
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
m.trimLocked()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||||
result := "success"
|
result := "success"
|
||||||
if !success {
|
if !success {
|
||||||
|
|||||||
@@ -78,6 +78,25 @@ Tags:
|
|||||||
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
|
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
|
||||||
- `arch`: CPU architecture (amd64, arm64, etc.)
|
- `arch`: CPU architecture (amd64, arm64, etc.)
|
||||||
|
|
||||||
|
### Sync Phase Timing
|
||||||
|
|
||||||
|
Measurement: `netbird_sync_phase`
|
||||||
|
|
||||||
|
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
|
||||||
|
|
||||||
|
| Field | Description |
|
||||||
|
|-------|-------------|
|
||||||
|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
|
||||||
|
|
||||||
|
Tags:
|
||||||
|
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
|
||||||
|
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
|
||||||
|
- `version`: NetBird version string
|
||||||
|
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
|
||||||
|
- `arch`: CPU architecture (amd64, arm64, etc.)
|
||||||
|
|
||||||
|
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
|
||||||
|
|
||||||
### Login Duration
|
### Login Duration
|
||||||
|
|
||||||
Measurement: `netbird_login`
|
Measurement: `netbird_login`
|
||||||
@@ -191,4 +210,52 @@ docker compose exec influxdb influx query \
|
|||||||
|
|
||||||
# Check ingest server health
|
# Check ingest server health
|
||||||
curl http://localhost:8087/health
|
curl http://localhost:8087/health
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Analyzing a Debug Bundle
|
||||||
|
|
||||||
|
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
|
||||||
|
|
||||||
|
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
|
||||||
|
|
||||||
|
### 1. Start the stack
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# From this directory (client/internal/metrics/infra)
|
||||||
|
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
(`admin123` are throwaway local credentials — fine for offline analysis.)
|
||||||
|
|
||||||
|
### 2. Clear any previous data
|
||||||
|
|
||||||
|
So you only see this bundle:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
|
||||||
|
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Import the bundle's metrics.txt
|
||||||
|
|
||||||
|
InfluxDB is not exposed on the host, so import inside the container:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
|
||||||
|
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
|
||||||
|
--token admin123 --file /tmp/m.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
|
||||||
|
|
||||||
|
### 4. View the dashboards
|
||||||
|
|
||||||
|
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
|
||||||
|
|
||||||
|
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
|
||||||
|
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
|
||||||
|
|
||||||
|
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
|
||||||
|
|
||||||
|
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.
|
||||||
@@ -0,0 +1,259 @@
|
|||||||
|
{
|
||||||
|
"annotations": {
|
||||||
|
"list": []
|
||||||
|
},
|
||||||
|
"editable": true,
|
||||||
|
"fiscalYearStartMonth": 0,
|
||||||
|
"graphTooltip": 1,
|
||||||
|
"links": [],
|
||||||
|
"refresh": "",
|
||||||
|
"schemaVersion": 39,
|
||||||
|
"tags": [
|
||||||
|
"netbird",
|
||||||
|
"sync"
|
||||||
|
],
|
||||||
|
"templating": {
|
||||||
|
"list": [
|
||||||
|
{
|
||||||
|
"current": {
|
||||||
|
"text": "All",
|
||||||
|
"value": "$__all"
|
||||||
|
},
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
|
||||||
|
"includeAll": true,
|
||||||
|
"label": "version",
|
||||||
|
"multi": true,
|
||||||
|
"name": "version",
|
||||||
|
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
|
||||||
|
"refresh": 2,
|
||||||
|
"type": "query",
|
||||||
|
"allValue": ".*"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"time": {
|
||||||
|
"from": "now-2d",
|
||||||
|
"to": "now"
|
||||||
|
},
|
||||||
|
"timepicker": {},
|
||||||
|
"timezone": "",
|
||||||
|
"title": "NetBird Sync Phases (where time goes)",
|
||||||
|
"uid": "netbird-sync-phases",
|
||||||
|
"version": 1,
|
||||||
|
"panels": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"title": "Time per phase over time (stacked, ms)",
|
||||||
|
"type": "timeseries",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"gridPos": {
|
||||||
|
"h": 10,
|
||||||
|
"w": 24,
|
||||||
|
"x": 0,
|
||||||
|
"y": 0
|
||||||
|
},
|
||||||
|
"fieldConfig": {
|
||||||
|
"defaults": {
|
||||||
|
"unit": "ms",
|
||||||
|
"custom": {
|
||||||
|
"drawStyle": "bars",
|
||||||
|
"stacking": {
|
||||||
|
"mode": "normal",
|
||||||
|
"group": "A"
|
||||||
|
},
|
||||||
|
"fillOpacity": 80,
|
||||||
|
"lineWidth": 0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"overrides": []
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"legend": {
|
||||||
|
"displayMode": "table",
|
||||||
|
"placement": "right",
|
||||||
|
"calcs": [
|
||||||
|
"max",
|
||||||
|
"mean"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"tooltip": {
|
||||||
|
"mode": "multi",
|
||||||
|
"sort": "desc"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"refId": "A",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"title": "p95 per phase (ms)",
|
||||||
|
"type": "bargauge",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"gridPos": {
|
||||||
|
"h": 11,
|
||||||
|
"w": 12,
|
||||||
|
"x": 0,
|
||||||
|
"y": 10
|
||||||
|
},
|
||||||
|
"fieldConfig": {
|
||||||
|
"defaults": {
|
||||||
|
"unit": "ms",
|
||||||
|
"color": {
|
||||||
|
"mode": "continuous-GrYlRd"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"overrides": []
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"displayMode": "gradient",
|
||||||
|
"orientation": "horizontal",
|
||||||
|
"reduceOptions": {
|
||||||
|
"calcs": [
|
||||||
|
"lastNotNull"
|
||||||
|
],
|
||||||
|
"fields": "",
|
||||||
|
"values": false
|
||||||
|
},
|
||||||
|
"showUnfilled": true
|
||||||
|
},
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"refId": "A",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"title": "Per-phase stats (ms): mean / p95 / max",
|
||||||
|
"type": "table",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"gridPos": {
|
||||||
|
"h": 11,
|
||||||
|
"w": 12,
|
||||||
|
"x": 12,
|
||||||
|
"y": 10
|
||||||
|
},
|
||||||
|
"fieldConfig": {
|
||||||
|
"defaults": {
|
||||||
|
"unit": "ms"
|
||||||
|
},
|
||||||
|
"overrides": []
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"showHeader": true,
|
||||||
|
"sortBy": [
|
||||||
|
{
|
||||||
|
"displayName": "max",
|
||||||
|
"desc": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"transformations": [
|
||||||
|
{
|
||||||
|
"id": "merge",
|
||||||
|
"options": {}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"refId": "mean",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"refId": "p95",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"refId": "max",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4,
|
||||||
|
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
|
||||||
|
"type": "timeseries",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"gridPos": {
|
||||||
|
"h": 8,
|
||||||
|
"w": 24,
|
||||||
|
"x": 0,
|
||||||
|
"y": 21
|
||||||
|
},
|
||||||
|
"fieldConfig": {
|
||||||
|
"defaults": {
|
||||||
|
"unit": "ms",
|
||||||
|
"custom": {
|
||||||
|
"drawStyle": "points",
|
||||||
|
"pointSize": 5
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"overrides": []
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"legend": {
|
||||||
|
"displayMode": "table",
|
||||||
|
"placement": "right",
|
||||||
|
"calcs": [
|
||||||
|
"max",
|
||||||
|
"mean"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"tooltip": {
|
||||||
|
"mode": "single"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"targets": [
|
||||||
|
{
|
||||||
|
"refId": "A",
|
||||||
|
"datasource": {
|
||||||
|
"type": "influxdb",
|
||||||
|
"uid": "influxdb"
|
||||||
|
},
|
||||||
|
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -19,7 +19,7 @@ const (
|
|||||||
defaultListenAddr = ":8087"
|
defaultListenAddr = ":8087"
|
||||||
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
|
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
|
||||||
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
|
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
|
||||||
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
|
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
|
||||||
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
|
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
|
||||||
maxTagValueLength = 64 // reject tag values longer than this
|
maxTagValueLength = 64 // reject tag values longer than this
|
||||||
)
|
)
|
||||||
@@ -59,6 +59,19 @@ var allowedMeasurements = map[string]measurementSpec{
|
|||||||
"peer_id": true,
|
"peer_id": true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"netbird_sync_phase": {
|
||||||
|
allowedFields: map[string]bool{
|
||||||
|
"duration_seconds": true,
|
||||||
|
},
|
||||||
|
allowedTags: map[string]bool{
|
||||||
|
"deployment_type": true,
|
||||||
|
"version": true,
|
||||||
|
"os": true,
|
||||||
|
"arch": true,
|
||||||
|
"peer_id": true,
|
||||||
|
"phase": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
"netbird_login": {
|
"netbird_login": {
|
||||||
allowedFields: map[string]bool{
|
allowedFields: map[string]bool{
|
||||||
"duration_seconds": true,
|
"duration_seconds": true,
|
||||||
|
|||||||
@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateLine_DurationTooLarge(t *testing.T) {
|
func TestValidateLine_DurationTooLarge(t *testing.T) {
|
||||||
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
|
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
|
||||||
err := validateLine(line)
|
err := validateLine(line)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "too large")
|
assert.Contains(t, err.Error(), "too large")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
|
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
|
||||||
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
|
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
|
||||||
err := validateLine(line)
|
err := validateLine(line)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "too large")
|
assert.Contains(t, err.Error(), "too large")
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -56,6 +57,9 @@ type metricsImplementation interface {
|
|||||||
// RecordSyncDuration records how long it took to process a sync message
|
// RecordSyncDuration records how long it took to process a sync message
|
||||||
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
|
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
|
||||||
|
|
||||||
|
// RecordSyncPhase records how long a single sub-phase of sync processing took
|
||||||
|
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
|
||||||
|
|
||||||
// RecordLoginDuration records how long the login to management took
|
// RecordLoginDuration records how long the login to management took
|
||||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||||
|
|
||||||
@@ -72,7 +76,7 @@ type ClientMetrics struct {
|
|||||||
agentInfo AgentInfo
|
agentInfo AgentInfo
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
push *Push
|
push atomic.Pointer[Push]
|
||||||
pushMu sync.Mutex
|
pushMu sync.Mutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
pushCancel context.CancelFunc
|
pushCancel context.CancelFunc
|
||||||
@@ -127,6 +131,18 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
|||||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordSyncPhase records the duration of a single sub-phase of sync processing
|
||||||
|
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
agentInfo := c.agentInfo
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
|
||||||
|
}
|
||||||
|
|
||||||
// RecordLoginDuration records how long the login to management server took
|
// RecordLoginDuration records how long the login to management server took
|
||||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -152,10 +168,7 @@ func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) {
|
|||||||
c.agentInfo = agentInfo
|
c.agentInfo = agentInfo
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
c.pushMu.Lock()
|
if push := c.push.Load(); push != nil {
|
||||||
push := c.push
|
|
||||||
c.pushMu.Unlock()
|
|
||||||
if push != nil {
|
|
||||||
push.SetPeerID(agentInfo.peerID)
|
push.SetPeerID(agentInfo.peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,7 +182,7 @@ func (c *ClientMetrics) Export(w io.Writer) error {
|
|||||||
return c.impl.Export(w)
|
return c.impl.Export(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartPush starts periodic pushing of metrics with the given configuration
|
// StartPush starts periodic pushing of metrics with the given configuration.
|
||||||
// Precedence: PushConfig.ServerAddress > remote config server_url
|
// Precedence: PushConfig.ServerAddress > remote config server_url
|
||||||
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -179,11 +192,58 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
|||||||
c.pushMu.Lock()
|
c.pushMu.Lock()
|
||||||
defer c.pushMu.Unlock()
|
defer c.pushMu.Unlock()
|
||||||
|
|
||||||
if c.push != nil {
|
if c.push.Load() != nil {
|
||||||
log.Warnf("metrics push already running")
|
log.Warnf("metrics push already running")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.startPushLocked(ctx, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopPush stops the periodic metrics push.
|
||||||
|
func (c *ClientMetrics) StopPush() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.pushMu.Lock()
|
||||||
|
defer c.pushMu.Unlock()
|
||||||
|
|
||||||
|
c.stopPushLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePushFromMgm updates metrics push based on management server configuration.
|
||||||
|
// If NB_METRICS_PUSH_ENABLED is explicitly set (true or false), management config is ignored.
|
||||||
|
// When unset, management controls whether push is enabled.
|
||||||
|
func (c *ClientMetrics) UpdatePushFromMgm(ctx context.Context, enabled bool) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isMetricsPushEnvSet() {
|
||||||
|
log.Debugf("ignoring management config, env var is explicitly set: %s", EnvMetricsPushEnabled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.pushMu.Lock()
|
||||||
|
defer c.pushMu.Unlock()
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
if c.push.Load() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("enabled metrics push by management")
|
||||||
|
c.startPushLocked(ctx, PushConfigFromEnv())
|
||||||
|
} else {
|
||||||
|
if c.push.Load() == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("disabled metrics push by management")
|
||||||
|
c.stopPushLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPushLocked starts push. Caller must hold pushMu.
|
||||||
|
func (c *ClientMetrics) startPushLocked(ctx context.Context, config PushConfig) {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
agentVersion := c.agentInfo.Version
|
agentVersion := c.agentInfo.Version
|
||||||
peerID := c.agentInfo.peerID
|
peerID := c.agentInfo.peerID
|
||||||
@@ -199,26 +259,23 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
c.pushCancel = cancel
|
c.pushCancel = cancel
|
||||||
|
c.push.Store(push)
|
||||||
|
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer c.wg.Done()
|
defer c.wg.Done()
|
||||||
push.Start(ctx)
|
push.Start(ctx)
|
||||||
|
c.push.CompareAndSwap(push, nil)
|
||||||
}()
|
}()
|
||||||
c.push = push
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientMetrics) StopPush() {
|
// stopPushLocked stops push. Caller must hold pushMu.
|
||||||
if c == nil {
|
func (c *ClientMetrics) stopPushLocked() {
|
||||||
return
|
if c.push.Load() == nil {
|
||||||
}
|
|
||||||
c.pushMu.Lock()
|
|
||||||
defer c.pushMu.Unlock()
|
|
||||||
if c.push == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.pushCancel()
|
c.pushCancel()
|
||||||
c.wg.Wait()
|
c.wg.Wait()
|
||||||
c.push = nil
|
c.push.Store(nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
|
|||||||
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
|
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -136,6 +137,39 @@ type Conn struct {
|
|||||||
// Connection stage timestamps for metrics
|
// Connection stage timestamps for metrics
|
||||||
metricsRecorder MetricsRecorder
|
metricsRecorder MetricsRecorder
|
||||||
metricsStages *MetricsStages
|
metricsStages *MetricsStages
|
||||||
|
|
||||||
|
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
|
||||||
|
// transport is up.
|
||||||
|
pendingFirstPacket []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
|
||||||
|
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
|
||||||
|
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
|
||||||
|
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
|
||||||
|
pkt := conn.pendingFirstPacket
|
||||||
|
if len(pkt) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case proxy != nil:
|
||||||
|
if err := proxy.InjectPacket(pkt); err != nil {
|
||||||
|
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case directConn != nil:
|
||||||
|
if _, err := directConn.Write(pkt); err != nil {
|
||||||
|
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
conn.Log.Debugf("no transport available to reinject captured first packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.pendingFirstPacket = nil
|
||||||
|
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -172,6 +206,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||||
|
return conn.open(engineCtx, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
|
||||||
|
// the real transport is established. The packet is retained only on a successful open.
|
||||||
|
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
|
||||||
|
return conn.open(engineCtx, firstPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -227,6 +271,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
defer conn.wg.Done()
|
defer conn.wg.Done()
|
||||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||||
}()
|
}()
|
||||||
|
if len(firstPacket) > 0 {
|
||||||
|
conn.pendingFirstPacket = slices.Clone(firstPacket)
|
||||||
|
}
|
||||||
conn.opened = true
|
conn.opened = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -423,6 +470,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.wgProxyRelay.RedirectAs(ep)
|
conn.wgProxyRelay.RedirectAs(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo, updateTime)
|
conn.updateIceState(iceConnInfo, updateTime)
|
||||||
@@ -546,6 +595,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
|
conn.injectPendingFirstPacket(wgProxy, nil)
|
||||||
|
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
conn.statusRelay.SetConnected()
|
conn.statusRelay.SetConnected()
|
||||||
@@ -752,15 +803,17 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||||
if !conn.wgWatcher.IsEnabled() {
|
if !conn.wgWatcher.PrepareInitialHandshake() {
|
||||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
return
|
||||||
conn.wgWatcherCancel = wgWatcherCancel
|
|
||||||
conn.wgWatcherWg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer conn.wgWatcherWg.Done()
|
|
||||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||||
|
conn.wgWatcherCancel = wgWatcherCancel
|
||||||
|
conn.wgWatcherWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer conn.wgWatcherWg.Done()
|
||||||
|
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) disableWgWatcherIfNeeded() {
|
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||||
|
|||||||
@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
offer := h.buildOfferAnswer()
|
offer := h.buildOfferAnswer()
|
||||||
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
|
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
|
||||||
|
|
||||||
return h.signaler.SignalOffer(offer, h.config.Key)
|
return h.signaler.SignalOffer(offer, h.config.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) sendAnswer() error {
|
func (h *Handshaker) sendAnswer() error {
|
||||||
answer := h.buildOfferAnswer()
|
answer := h.buildOfferAnswer()
|
||||||
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
|
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
|
||||||
|
|
||||||
return h.signaler.SignalAnswer(answer, h.config.Key)
|
return h.signaler.SignalAnswer(answer, h.config.Key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
|||||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||||
type Status struct {
|
type Status struct {
|
||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
|
muxRelays sync.RWMutex
|
||||||
peers map[string]State
|
peers map[string]State
|
||||||
ipToKey map[string]string
|
ipToKey map[string]string
|
||||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||||
@@ -244,8 +245,8 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
||||||
d.mux.Lock()
|
d.muxRelays.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.muxRelays.Unlock()
|
||||||
d.relayMgr = manager
|
d.relayMgr = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -906,8 +907,8 @@ func (d *Status) MarkSignalConnected() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||||
d.mux.Lock()
|
d.muxRelays.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.muxRelays.Unlock()
|
||||||
d.relayStates = relayResults
|
d.relayStates = relayResults
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1018,24 +1019,26 @@ func (d *Status) GetSignalState() SignalState {
|
|||||||
|
|
||||||
// GetRelayStates returns the stun/turn/permanent relay states
|
// GetRelayStates returns the stun/turn/permanent relay states
|
||||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||||
d.mux.RLock()
|
d.muxRelays.RLock()
|
||||||
defer d.mux.RUnlock()
|
|
||||||
if d.relayMgr == nil {
|
if d.relayMgr == nil {
|
||||||
return d.relayStates
|
defer d.muxRelays.RUnlock()
|
||||||
|
return slices.Clone(d.relayStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayMgr := d.relayMgr
|
||||||
// extend the list of stun, turn servers with the relay server connections
|
// extend the list of stun, turn servers with the relay server connections
|
||||||
relayStates := slices.Clone(d.relayStates)
|
relayStates := slices.Clone(d.relayStates)
|
||||||
|
d.muxRelays.RUnlock()
|
||||||
|
|
||||||
states := d.relayMgr.RelayStates()
|
states := relayMgr.RelayStates()
|
||||||
if len(states) == 0 {
|
if len(states) == 0 {
|
||||||
// no relay connection tracked yet; surface configured servers as
|
// no relay connection tracked yet; surface configured servers as
|
||||||
// unavailable with the real reconnect error when known
|
// unavailable with the real reconnect error when known
|
||||||
err := relayClient.ErrRelayClientNotConnected
|
err := relayClient.ErrRelayClientNotConnected
|
||||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
if connErr := relayMgr.RelayConnectError(); connErr != nil {
|
||||||
err = connErr
|
err = connErr
|
||||||
}
|
}
|
||||||
for _, r := range d.relayMgr.ServerURLs() {
|
for _, r := range relayMgr.ServerURLs() {
|
||||||
relayStates = append(relayStates, relay.ProbeResult{
|
relayStates = append(relayStates, relay.ProbeResult{
|
||||||
URI: r,
|
URI: r,
|
||||||
Err: err,
|
Err: err,
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ type WGWatcher struct {
|
|||||||
stateDump *stateDump
|
stateDump *stateDump
|
||||||
|
|
||||||
enabled bool
|
enabled bool
|
||||||
muEnabled sync.RWMutex
|
muEnabled sync.Mutex
|
||||||
|
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
|
||||||
|
initialHandshake time.Time
|
||||||
|
|
||||||
resetCh chan struct{}
|
resetCh chan struct{}
|
||||||
}
|
}
|
||||||
@@ -46,38 +48,38 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
|
||||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
// handshake time. It must be called before the peer is (re)configured on the WireGuard
|
||||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
// interface, so the captured baseline reflects the state prior to this connection attempt
|
||||||
|
// instead of racing with that configuration. Returns ok=false if the watcher is already
|
||||||
|
// running, in which case EnableWgWatcher must not be called.
|
||||||
|
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
|
||||||
w.muEnabled.Lock()
|
w.muEnabled.Lock()
|
||||||
if w.enabled {
|
if w.enabled {
|
||||||
w.muEnabled.Unlock()
|
w.muEnabled.Unlock()
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
w.log.Debugf("enable WireGuard watcher")
|
w.log.Debugf("enable WireGuard watcher")
|
||||||
w.enabled = true
|
w.enabled = true
|
||||||
w.muEnabled.Unlock()
|
w.muEnabled.Unlock()
|
||||||
|
|
||||||
initialHandshake, err := w.wgState()
|
handshake, _ := w.wgState()
|
||||||
if err != nil {
|
w.initialHandshake = handshake
|
||||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake)
|
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
|
||||||
|
// PrepareInitialHandshake. The watcher runs until ctx is cancelled. Caller is responsible
|
||||||
|
// for context lifecycle management.
|
||||||
|
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||||
|
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
|
||||||
|
|
||||||
w.muEnabled.Lock()
|
w.muEnabled.Lock()
|
||||||
w.enabled = false
|
w.enabled = false
|
||||||
w.muEnabled.Unlock()
|
w.muEnabled.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
|
||||||
func (w *WGWatcher) IsEnabled() bool {
|
|
||||||
w.muEnabled.RLock()
|
|
||||||
defer w.muEnabled.RUnlock()
|
|
||||||
return w.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||||
func (w *WGWatcher) Reset() {
|
func (w *WGWatcher) Reset() {
|
||||||
@@ -101,13 +103,16 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
|||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
handshake, ok := w.handshakeCheck(lastHandshake)
|
handshake, ok := w.handshakeCheck(lastHandshake)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
onDisconnectedFn()
|
onDisconnectedFn()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if lastHandshake.IsZero() {
|
if lastHandshake.IsZero() {
|
||||||
elapsed := calcElapsed(enabledTime, *handshake)
|
elapsed := calcElapsed(enabledTime, *handshake)
|
||||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||||
if onHandshakeSuccessFn != nil {
|
if onHandshakeSuccessFn != nil && ctx.Err() == nil {
|
||||||
onHandshakeSuccessFn(*handshake)
|
onHandshakeSuccessFn(*handshake)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
)
|
)
|
||||||
@@ -34,6 +35,9 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
ok := watcher.PrepareInitialHandshake()
|
||||||
|
require.True(t, ok, "watcher should not be enabled yet")
|
||||||
|
|
||||||
onDisconnected := make(chan struct{}, 1)
|
onDisconnected := make(chan struct{}, 1)
|
||||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||||
mlog.Infof("onDisconnectedFn")
|
mlog.Infof("onDisconnectedFn")
|
||||||
@@ -62,6 +66,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
|||||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
ok := watcher.PrepareInitialHandshake()
|
||||||
|
require.True(t, ok, "watcher should not be enabled yet")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -76,6 +83,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
|||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
ok = watcher.PrepareInitialHandshake()
|
||||||
|
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
|
||||||
|
|
||||||
onDisconnected := make(chan struct{}, 1)
|
onDisconnected := make(chan struct{}, 1)
|
||||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||||
onDisconnected <- struct{}{}
|
onDisconnected <- struct{}{}
|
||||||
|
|||||||
@@ -88,11 +88,24 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// this can be blocked because of the connect open limiter semaphore
|
|
||||||
if err := p.Open(ctx); err != nil {
|
if err := p.Open(ctx); err != nil {
|
||||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
|
||||||
|
// reinjected once the real transport is established.
|
||||||
|
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
|
||||||
|
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) PeerConnIdle(pubKey string) {
|
func (s *Store) PeerConnIdle(pubKey string) {
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
if input.NetworkMonitor != nil && (config.NetworkMonitor == nil || *input.NetworkMonitor != *config.NetworkMonitor) {
|
||||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||||
config.NetworkMonitor = input.NetworkMonitor
|
config.NetworkMonitor = input.NetworkMonitor
|
||||||
updated = true
|
updated = true
|
||||||
@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
|
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
|
||||||
if *input.ServerSSHAllowed {
|
if *input.ServerSSHAllowed {
|
||||||
log.Infof("enabling SSH server")
|
log.Infof("enabling SSH server")
|
||||||
} else {
|
} else {
|
||||||
@@ -454,7 +454,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
if input.EnableSSHRoot != nil && (config.EnableSSHRoot == nil || *input.EnableSSHRoot != *config.EnableSSHRoot) {
|
||||||
if *input.EnableSSHRoot {
|
if *input.EnableSSHRoot {
|
||||||
log.Infof("enabling SSH root login")
|
log.Infof("enabling SSH root login")
|
||||||
} else {
|
} else {
|
||||||
@@ -464,7 +464,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
if input.EnableSSHSFTP != nil && (config.EnableSSHSFTP == nil || *input.EnableSSHSFTP != *config.EnableSSHSFTP) {
|
||||||
if *input.EnableSSHSFTP {
|
if *input.EnableSSHSFTP {
|
||||||
log.Infof("enabling SSH SFTP subsystem")
|
log.Infof("enabling SSH SFTP subsystem")
|
||||||
} else {
|
} else {
|
||||||
@@ -474,7 +474,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
if input.EnableSSHLocalPortForwarding != nil && (config.EnableSSHLocalPortForwarding == nil || *input.EnableSSHLocalPortForwarding != *config.EnableSSHLocalPortForwarding) {
|
||||||
if *input.EnableSSHLocalPortForwarding {
|
if *input.EnableSSHLocalPortForwarding {
|
||||||
log.Infof("enabling SSH local port forwarding")
|
log.Infof("enabling SSH local port forwarding")
|
||||||
} else {
|
} else {
|
||||||
@@ -484,7 +484,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
if input.EnableSSHRemotePortForwarding != nil && (config.EnableSSHRemotePortForwarding == nil || *input.EnableSSHRemotePortForwarding != *config.EnableSSHRemotePortForwarding) {
|
||||||
if *input.EnableSSHRemotePortForwarding {
|
if *input.EnableSSHRemotePortForwarding {
|
||||||
log.Infof("enabling SSH remote port forwarding")
|
log.Infof("enabling SSH remote port forwarding")
|
||||||
} else {
|
} else {
|
||||||
@@ -494,7 +494,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
if input.DisableSSHAuth != nil && (config.DisableSSHAuth == nil || *input.DisableSSHAuth != *config.DisableSSHAuth) {
|
||||||
if *input.DisableSSHAuth {
|
if *input.DisableSSHAuth {
|
||||||
log.Infof("disabling SSH authentication")
|
log.Infof("disabling SSH authentication")
|
||||||
} else {
|
} else {
|
||||||
@@ -504,7 +504,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
if input.SSHJWTCacheTTL != nil && (config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) {
|
||||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
updated = true
|
updated = true
|
||||||
@@ -587,7 +587,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
if input.DisableNotifications != nil && (config.DisableNotifications == nil || *input.DisableNotifications != *config.DisableNotifications) {
|
||||||
if *input.DisableNotifications {
|
if *input.DisableNotifications {
|
||||||
log.Infof("disabling notifications")
|
log.Infof("disabling notifications")
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -242,6 +242,35 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
|
||||||
|
// Configs written before ServerSSHAllowed was introduced lack the field and
|
||||||
|
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
|
||||||
|
// apply the value instead of panicking on a nil pointer dereference.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input *bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"enable", util.True(), true},
|
||||||
|
{"disable", util.False(), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
|
||||||
|
|
||||||
|
config, err := UpdateConfig(ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
ServerSSHAllowed: tt.input,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
|
||||||
|
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateOldManagementURL(t *testing.T) {
|
func TestUpdateOldManagementURL(t *testing.T) {
|
||||||
origProber := newMgmProber
|
origProber := newMgmProber
|
||||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||||
|
|||||||
@@ -226,12 +226,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// pass if non A/AAAA query
|
// All query types for an intercepted domain are forwarded to the peer's
|
||||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
// DNS forwarder, which owns the name. Falling through to the system
|
||||||
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
|
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
|
||||||
return
|
// for, poisoning the whole name (including the A/AAAA records the route
|
||||||
}
|
// does serve). The forwarder answers NODATA for types it cannot resolve.
|
||||||
|
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
peerKey := d.currentPeerKey
|
peerKey := d.currentPeerKey
|
||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
@@ -293,19 +292,6 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// continueToNextHandler signals the handler chain to try the next handler
|
|
||||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
|
||||||
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
|
||||||
resp.SetRcode(r, dns.RcodeNameError)
|
|
||||||
// Set Zero bit to signal handler chain to continue
|
|
||||||
resp.MsgHdr.Zero = true
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed writing DNS continue response: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
||||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||||
if !exists {
|
if !exists {
|
||||||
|
|||||||
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newExitNodeTestManager() *DefaultManager {
|
||||||
|
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
||||||
|
return &route.Route{
|
||||||
|
NetID: route.NetID(netID),
|
||||||
|
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
Peer: peer,
|
||||||
|
SkipAutoApply: skipAutoApply,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPickPreferredExitNode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info exitNodeInfo
|
||||||
|
want route.NetID
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "persisted user selection wins over management",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"a", "b", "c"},
|
||||||
|
userSelected: []route.NetID{"b"},
|
||||||
|
selectedByManagement: []route.NetID{"a"},
|
||||||
|
},
|
||||||
|
want: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple user-selected self-heal to deterministic min",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"a", "b", "c"},
|
||||||
|
userSelected: []route.NetID{"c", "a"},
|
||||||
|
},
|
||||||
|
want: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit opt-out keeps none",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"a", "b"},
|
||||||
|
userDeselected: []route.NetID{"a", "b"},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fresh defaults to management auto-apply pick",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"a", "b", "c"},
|
||||||
|
selectedByManagement: []route.NetID{"b"},
|
||||||
|
},
|
||||||
|
want: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no user pick and no management auto-apply selects none",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"c", "a", "b"},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user-deselect does not block a management auto-apply sibling",
|
||||||
|
info: exitNodeInfo{
|
||||||
|
allIDs: []route.NetID{"a", "b"},
|
||||||
|
userDeselected: []route.NetID{"a"},
|
||||||
|
selectedByManagement: []route.NetID{"b"},
|
||||||
|
},
|
||||||
|
want: "b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceSingleExitNode(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
all := []route.NetID{"a", "b", "c"}
|
||||||
|
|
||||||
|
m.enforceSingleExitNode("b", all)
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
||||||
|
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
||||||
|
|
||||||
|
// Switching the preferred node moves the single selection.
|
||||||
|
m.enforceSingleExitNode("c", all)
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
||||||
|
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
||||||
|
|
||||||
|
// Empty preferred turns every exit node off.
|
||||||
|
m.enforceSingleExitNode("", all)
|
||||||
|
for _, id := range all {
|
||||||
|
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
m.routeSelector.DeselectAllRoutes()
|
||||||
|
|
||||||
|
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
||||||
|
|
||||||
|
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||||
|
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||||
|
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
||||||
|
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateRouteSelectorFromManagement(routes)
|
||||||
|
|
||||||
|
// Exactly one exit node (the deterministic first) is selected.
|
||||||
|
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
||||||
|
// Non-exit routes are left at their default-on state.
|
||||||
|
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||||
|
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||||
|
}
|
||||||
|
all := []route.NetID{"exitA", "exitB"}
|
||||||
|
|
||||||
|
// Simulate the state the runtime select path leaves behind: exactly one
|
||||||
|
// exit node explicitly selected, its sibling deselected.
|
||||||
|
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
||||||
|
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
||||||
|
|
||||||
|
m.updateRouteSelectorFromManagement(routes)
|
||||||
|
|
||||||
|
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||||
|
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||||
|
}
|
||||||
|
all := []route.NetID{"exitA", "exitB"}
|
||||||
|
|
||||||
|
// User deselected exit nodes and selected none.
|
||||||
|
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
||||||
|
|
||||||
|
m.updateRouteSelectorFromManagement(routes)
|
||||||
|
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
||||||
|
m := newExitNodeTestManager()
|
||||||
|
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
||||||
|
// auto-activation, so none should be selected until the user picks one.
|
||||||
|
routes := route.HAMap{
|
||||||
|
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
||||||
|
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateRouteSelectorFromManagement(routes)
|
||||||
|
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
||||||
|
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
||||||
|
}
|
||||||
@@ -701,7 +701,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
|||||||
return ips
|
return ips
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
||||||
|
// network map: it keeps at most one exit node selected — the user's persisted
|
||||||
|
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
||||||
|
// else none. We never auto-activate an exit node the map doesn't request; it
|
||||||
|
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
||||||
|
// RouteSelector stores routes with default-on semantics, so without this every
|
||||||
|
// available exit node would report selected at once.
|
||||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||||
|
|
||||||
@@ -712,13 +718,14 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
info := m.collectExitNodeInfo(clientRoutes)
|
||||||
if len(exitNodeInfo.allIDs) == 0 {
|
if len(info.allIDs) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateExitNodeSelections(exitNodeInfo)
|
preferred := pickPreferredExitNode(info)
|
||||||
m.logExitNodeUpdate(exitNodeInfo)
|
m.enforceSingleExitNode(preferred, info.allIDs)
|
||||||
|
m.logExitNodeUpdate(info, preferred)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||||
@@ -746,6 +753,10 @@ type exitNodeInfo struct {
|
|||||||
userDeselected []route.NetID
|
userDeselected []route.NetID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
||||||
|
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
||||||
|
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
||||||
|
// — counting it separately would double-count the pair.
|
||||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||||
var info exitNodeInfo
|
var info exitNodeInfo
|
||||||
|
|
||||||
@@ -755,6 +766,9 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
|||||||
}
|
}
|
||||||
|
|
||||||
netID := haID.NetID()
|
netID := haID.NetID()
|
||||||
|
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
info.allIDs = append(info.allIDs, netID)
|
info.allIDs = append(info.allIDs, netID)
|
||||||
|
|
||||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||||
@@ -791,45 +805,52 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
||||||
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
// - a persisted user selection wins (deterministic if several survive from
|
||||||
m.deselectExitNodes(routesToDeselect)
|
// legacy state, so the set self-heals down to one);
|
||||||
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
// - otherwise activate only what management marks for auto-apply
|
||||||
|
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
||||||
|
//
|
||||||
|
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
||||||
|
// route the map doesn't auto-apply stays off until the user selects it.
|
||||||
|
// info.userDeselected is informational only: an explicit deselect simply keeps
|
||||||
|
// that route out of both lists above, so it can't be picked.
|
||||||
|
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
||||||
|
if len(info.userSelected) > 0 {
|
||||||
|
return minNetID(info.userSelected)
|
||||||
|
}
|
||||||
|
if len(info.selectedByManagement) > 0 {
|
||||||
|
return minNetID(info.selectedByManagement)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
||||||
var routesToDeselect []route.NetID
|
// available exit node is deselected and preferred (if any) is selected, without
|
||||||
for _, netID := range allIDs {
|
// disturbing non-exit route selections. The whole reconciliation runs under a
|
||||||
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
|
||||||
routesToDeselect = append(routesToDeselect, netID)
|
// cannot interleave and get undone; a global deselect-all is left untouched so
|
||||||
|
// the user's "all off" stays in effect.
|
||||||
|
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
||||||
|
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
||||||
|
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
||||||
|
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
||||||
|
}
|
||||||
|
|
||||||
|
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
||||||
|
// default pick that stays stable across restarts.
|
||||||
|
func minNetID(ids []route.NetID) route.NetID {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
best := ids[0]
|
||||||
|
for _, id := range ids[1:] {
|
||||||
|
if id < best {
|
||||||
|
best = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routesToDeselect
|
return best
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
|
||||||
if len(routesToDeselect) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Failed to deselect exit nodes: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
|
||||||
if len(selectedByManagement) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Failed to select exit nodes: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
|
||||||
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
|
||||||
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEntryExists(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||||
|
|
||||||
|
content := []string{
|
||||||
|
"1000 reserved",
|
||||||
|
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||||
|
"9999 other_table",
|
||||||
|
}
|
||||||
|
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||||
|
|
||||||
|
file, err := os.Open(tempFilePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
assert.NoError(t, file.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id int
|
||||||
|
shouldExist bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ExistsWithNetbirdPrefix",
|
||||||
|
id: 7120,
|
||||||
|
shouldExist: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ExistsWithDifferentName",
|
||||||
|
id: 1000,
|
||||||
|
shouldExist: true,
|
||||||
|
err: ErrTableIDExists,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DoesNotExist",
|
||||||
|
id: 1234,
|
||||||
|
shouldExist: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
exists, err := entryExists(file, tc.id)
|
||||||
|
if tc.err != nil {
|
||||||
|
assert.ErrorIs(t, err, tc.err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.shouldExist, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,191 @@
|
|||||||
|
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCases = append(testCases, []testCase{
|
||||||
|
{
|
||||||
|
name: "To more specific route without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRoutes(t *testing.T) {
|
||||||
|
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
var intf *net.Interface
|
||||||
|
var nexthop Nexthop
|
||||||
|
|
||||||
|
_, intf = setupDummyInterface(t)
|
||||||
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
|
r := New(nil, nil)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||||
|
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||||
|
|
||||||
|
for i := 0; i < 1024; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ip netip.Addr) {
|
||||||
|
defer wg.Done()
|
||||||
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
|
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||||
|
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}(baseIP)
|
||||||
|
baseIP = baseIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||||
|
require.NoError(t, err, "Failed to create loopback alias")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||||
|
})
|
||||||
|
|
||||||
|
return intf
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||||
|
require.NoError(t, err, "Failed to parse prefix")
|
||||||
|
|
||||||
|
netIntf, err := net.InterfaceByName(intf)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
|
r := New(nil, nil)
|
||||||
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := r.removeFromRouteTable(prefix, nexthop)
|
||||||
|
assert.NoError(t, err, "Failed to remove route from table")
|
||||||
|
})
|
||||||
|
|
||||||
|
return intf
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var originalNexthop net.IP
|
||||||
|
if dstCIDR == "0.0.0.0/0" {
|
||||||
|
var err error
|
||||||
|
originalNexthop, err = fetchOriginalGateway()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to fetch original gateway: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||||
|
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if originalNexthop != nil {
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||||
|
assert.NoError(t, err, "Failed to restore original route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||||
|
require.NoError(t, err, "Failed to add route")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove route")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchOriginalGateway() (net.IP, error) {
|
||||||
|
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, fmt.Errorf("gateway not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ParseIP(matches[1]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||||
|
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||||
|
|
||||||
|
tunName := strings.TrimSpace(string(output))
|
||||||
|
|
||||||
|
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||||
|
|
||||||
|
intf, err := net.InterfaceByName(tunName)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||||
|
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||||
|
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||||
|
|
||||||
|
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||||
|
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||||
|
}
|
||||||
@@ -3,79 +3,24 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files in this package compile.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedVPNint = "utun100"
|
var expectedVPNint = "utun100"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedExternalInt = "lo0"
|
var expectedExternalInt = "lo0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
var expectedInternalInt = "lo0"
|
var expectedInternalInt = "lo0"
|
||||||
|
|
||||||
func init() {
|
|
||||||
testCases = append(testCases, []testCase{
|
|
||||||
{
|
|
||||||
name: "To more specific route without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrentRoutes(t *testing.T) {
|
|
||||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
|
||||||
|
|
||||||
var intf *net.Interface
|
|
||||||
var nexthop Nexthop
|
|
||||||
|
|
||||||
_, intf = setupDummyInterface(t)
|
|
||||||
nexthop = Nexthop{netip.Addr{}, intf}
|
|
||||||
|
|
||||||
r := New(nil, nil)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 1024; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(ip netip.Addr) {
|
|
||||||
defer wg.Done()
|
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
|
||||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
|
||||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
|
||||||
}
|
|
||||||
}(baseIP)
|
|
||||||
baseIP = baseIP.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
|
||||||
|
|
||||||
for i := 0; i < 1024; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(ip netip.Addr) {
|
|
||||||
defer wg.Done()
|
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
|
||||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
|
||||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
|
||||||
}
|
|
||||||
}(baseIP)
|
|
||||||
baseIP = baseIP.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBits(t *testing.T) {
|
func TestBits(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -122,122 +67,3 @@ func TestBits(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
|
||||||
require.NoError(t, err, "Failed to create loopback alias")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
|
||||||
})
|
|
||||||
|
|
||||||
return intf
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
|
||||||
require.NoError(t, err, "Failed to parse prefix")
|
|
||||||
|
|
||||||
netIntf, err := net.InterfaceByName(intf)
|
|
||||||
require.NoError(t, err, "Failed to get interface by name")
|
|
||||||
|
|
||||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
|
||||||
|
|
||||||
r := New(nil, nil)
|
|
||||||
err = r.addToRouteTable(prefix, nexthop)
|
|
||||||
require.NoError(t, err, "Failed to add route to table")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := r.removeFromRouteTable(prefix, nexthop)
|
|
||||||
assert.NoError(t, err, "Failed to remove route from table")
|
|
||||||
})
|
|
||||||
|
|
||||||
return intf
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
var originalNexthop net.IP
|
|
||||||
if dstCIDR == "0.0.0.0/0" {
|
|
||||||
var err error
|
|
||||||
originalNexthop, err = fetchOriginalGateway()
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Failed to fetch original gateway: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
|
||||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if originalNexthop != nil {
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
|
||||||
assert.NoError(t, err, "Failed to restore original route")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
|
||||||
require.NoError(t, err, "Failed to add route")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove route")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOriginalGateway() (net.IP, error) {
|
|
||||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return nil, fmt.Errorf("gateway not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.ParseIP(matches[1]), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
|
||||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
|
||||||
|
|
||||||
tunName := strings.TrimSpace(string(output))
|
|
||||||
|
|
||||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
|
||||||
|
|
||||||
intf, err := net.InterfaceByName(tunName)
|
|
||||||
require.NoError(t, err, "Failed to get interface by name")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
|
||||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
|
||||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
|
||||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialer is shared by the per-platform routing test cases. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files compile on every platform.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
type dialer interface {
|
||||||
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !android && !ios
|
//go:build !android && !ios && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
@@ -26,11 +26,6 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
|
||||||
Dial(network, address string) (net.Conn, error)
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddVPNRoute(t *testing.T) {
|
func TestAddVPNRoute(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -515,125 +510,3 @@ func setupTestEnv(t *testing.T) {
|
|||||||
// unique route in vpn table
|
// unique route in vpn table
|
||||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsVpnRoute(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
addr string
|
|
||||||
vpnRoutes []string
|
|
||||||
localRoutes []string
|
|
||||||
expectedVpn bool
|
|
||||||
expectedPrefix netip.Prefix
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Match in VPN routes",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Match in local routes",
|
|
||||||
addr: "10.1.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No match",
|
|
||||||
addr: "172.16.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.Prefix{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Default route ignored",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Default route matches but ignored",
|
|
||||||
addr: "172.16.1.1",
|
|
||||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"10.0.0.0/8"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.Prefix{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match local",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16"},
|
|
||||||
localRoutes: []string{"192.168.1.0/24"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match local multiple",
|
|
||||||
addr: "192.168.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
|
||||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match vpn",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"192.168.0.0/16"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Longest prefix match vpn multiple",
|
|
||||||
addr: "192.168.0.1",
|
|
||||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
|
||||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
|
||||||
expectedVpn: true,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Duplicate prefix in both",
|
|
||||||
addr: "192.168.1.1",
|
|
||||||
vpnRoutes: []string{"192.168.1.0/24"},
|
|
||||||
localRoutes: []string{"192.168.1.0/24"},
|
|
||||||
expectedVpn: false,
|
|
||||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr, err := netip.ParseAddr(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var vpnRoutes, localRoutes []netip.Prefix
|
|
||||||
for _, route := range tt.vpnRoutes {
|
|
||||||
prefix, err := netip.ParsePrefix(route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
|
||||||
}
|
|
||||||
vpnRoutes = append(vpnRoutes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range tt.localRoutes {
|
|
||||||
prefix, err := netip.ParsePrefix(route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
|
||||||
}
|
|
||||||
localRoutes = append(localRoutes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
|
||||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
|
||||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsVpnRoute(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
vpnRoutes []string
|
||||||
|
localRoutes []string
|
||||||
|
expectedVpn bool
|
||||||
|
expectedPrefix netip.Prefix
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Match in VPN routes",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Match in local routes",
|
||||||
|
addr: "10.1.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No match",
|
||||||
|
addr: "172.16.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.Prefix{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Default route ignored",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Default route matches but ignored",
|
||||||
|
addr: "172.16.1.1",
|
||||||
|
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"10.0.0.0/8"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.Prefix{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match local",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16"},
|
||||||
|
localRoutes: []string{"192.168.1.0/24"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match local multiple",
|
||||||
|
addr: "192.168.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||||
|
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match vpn",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"192.168.0.0/16"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Longest prefix match vpn multiple",
|
||||||
|
addr: "192.168.0.1",
|
||||||
|
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||||
|
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||||
|
expectedVpn: true,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Duplicate prefix in both",
|
||||||
|
addr: "192.168.1.1",
|
||||||
|
vpnRoutes: []string{"192.168.1.0/24"},
|
||||||
|
localRoutes: []string{"192.168.1.0/24"},
|
||||||
|
expectedVpn: false,
|
||||||
|
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, err := netip.ParseAddr(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vpnRoutes, localRoutes []netip.Prefix
|
||||||
|
for _, route := range tt.vpnRoutes {
|
||||||
|
prefix, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||||
|
}
|
||||||
|
vpnRoutes = append(vpnRoutes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range tt.localRoutes {
|
||||||
|
prefix, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||||
|
}
|
||||||
|
localRoutes = append(localRoutes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||||
|
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||||
|
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,10 @@
|
|||||||
//go:build !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -18,10 +15,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
)
|
)
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
|
||||||
var expectedExternalInt = "dummyext0"
|
|
||||||
var expectedInternalInt = "dummyint0"
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
testCases = append(testCases, []testCase{
|
testCases = append(testCases, []testCase{
|
||||||
{
|
{
|
||||||
@@ -33,62 +26,6 @@ func init() {
|
|||||||
}...)
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEntryExists(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
|
||||||
|
|
||||||
content := []string{
|
|
||||||
"1000 reserved",
|
|
||||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
|
||||||
"9999 other_table",
|
|
||||||
}
|
|
||||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
|
||||||
|
|
||||||
file, err := os.Open(tempFilePath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
assert.NoError(t, file.Close())
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id int
|
|
||||||
shouldExist bool
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ExistsWithNetbirdPrefix",
|
|
||||||
id: 7120,
|
|
||||||
shouldExist: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ExistsWithDifferentName",
|
|
||||||
id: 1000,
|
|
||||||
shouldExist: true,
|
|
||||||
err: ErrTableIDExists,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DoesNotExist",
|
|
||||||
id: 1234,
|
|
||||||
shouldExist: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
exists, err := entryExists(file, tc.id)
|
|
||||||
if tc.err != nil {
|
|
||||||
assert.ErrorIs(t, err, tc.err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.shouldExist, exists)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// Interface names used by the shared routing test fixtures. Kept untagged (no
|
||||||
|
// privileged build tag) so the non-privileged test files in this package compile.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedVPNint = "wgtest0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedExternalInt = "dummyext0"
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var expectedInternalInt = "dummyint0"
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
|
||||||
|
// per-platform init() appenders) consume these; they live here so the unprivileged
|
||||||
|
// BSD/darwin test files compile without the privileged build tag.
|
||||||
|
|
||||||
|
type PacketExpectation struct {
|
||||||
|
SrcIP net.IP
|
||||||
|
DstIP net.IP
|
||||||
|
SrcPort int
|
||||||
|
DstPort int
|
||||||
|
UDP bool
|
||||||
|
TCP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
expectedInterface string
|
||||||
|
dialer dialer
|
||||||
|
expectedPacket PacketExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
var testCases = []testCase{
|
||||||
|
{
|
||||||
|
name: "To external host without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To external host with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||||
|
expectedInterface: expectedInternalInt,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
|
expectedInterface: expectedExternalInt,
|
||||||
|
dialer: nbnet.NewDialer(),
|
||||||
|
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To unique vpn route without custom dialer via vpn",
|
||||||
|
expectedInterface: expectedVPNint,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
|
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||||
|
return PacketExpectation{
|
||||||
|
SrcIP: net.ParseIP(srcIP),
|
||||||
|
DstIP: net.ParseIP(dstIP),
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
UDP: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
@@ -20,63 +20,6 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketExpectation struct {
|
|
||||||
SrcIP net.IP
|
|
||||||
DstIP net.IP
|
|
||||||
SrcPort int
|
|
||||||
DstPort int
|
|
||||||
UDP bool
|
|
||||||
TCP bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type testCase struct {
|
|
||||||
name string
|
|
||||||
expectedInterface string
|
|
||||||
dialer dialer
|
|
||||||
expectedPacket PacketExpectation
|
|
||||||
}
|
|
||||||
|
|
||||||
var testCases = []testCase{
|
|
||||||
{
|
|
||||||
name: "To external host without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To external host with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To unique vpn route with custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To unique vpn route without custom dialer via vpn",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
nbnet.Init()
|
nbnet.Init()
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -102,16 +45,6 @@ func TestRouting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
|
||||||
return PacketExpectation{
|
|
||||||
SrcIP: net.ParseIP(srcIP),
|
|
||||||
DstIP: net.ParseIP(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
UDP: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build windows && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||||
// without v6 connectivity. If a default already exists it is left alone.
|
// without v6 connectivity. If a default already exists it is left alone.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android && privileged
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
|
||||||
|
|
||||||
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
|
||||||
// interface so route lookups for global IPv6 prefixes resolve in environments
|
// interface so route lookups for global IPv6 prefixes resolve in environments
|
||||||
// without v6 connectivity. If a default already exists it is left alone.
|
// without v6 connectivity. If a default already exists it is left alone.
|
||||||
|
//
|
||||||
|
//nolint:unused // consumed by the privileged-tagged routing tests
|
||||||
func ensureIPv6DefaultRoute(t *testing.T) {
|
func ensureIPv6DefaultRoute(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,38 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
|||||||
clear(rs.selectedRoutes)
|
clear(rs.selectedRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
// SetExclusiveExitNode atomically makes preferred the only selected exit node
|
||||||
|
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
|
||||||
|
// non-empty) is selected, all under a single lock. Holding the lock across the
|
||||||
|
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
|
||||||
|
// between the deselect and select steps and being silently undone. A global
|
||||||
|
// deselect-all is left untouched so the user's "all off" stays in effect;
|
||||||
|
// non-exit routes are never referenced, so their selection is preserved.
|
||||||
|
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
|
if rs.deselectAll {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range exitIDs {
|
||||||
|
if id == preferred {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rs.deselectedRoutes[id] = struct{}{}
|
||||||
|
delete(rs.selectedRoutes, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preferred != "" {
|
||||||
|
delete(rs.deselectedRoutes, preferred)
|
||||||
|
rs.selectedRoutes[preferred] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
|
||||||
|
// user explicitly disabled every route. Callers enforcing per-route invariants
|
||||||
|
// (e.g. single exit node) should leave the selection untouched when it is.
|
||||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|||||||
235
client/server/server_privileged_test.go
Normal file
235
client/server/server_privileged_test.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kaep = keepalive.EnforcementPolicy{
|
||||||
|
MinTime: 15 * time.Second,
|
||||||
|
PermitWithoutStream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
kasp = keepalive.ServerParameters{
|
||||||
|
MaxConnectionIdle: 15 * time.Second,
|
||||||
|
MaxConnectionAgeGrace: 5 * time.Second,
|
||||||
|
Time: 5 * time.Second,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
||||||
|
// we will use a management server started via to simulate the server and capture the number of retries
|
||||||
|
func TestConnectWithRetryRuns(t *testing.T) {
|
||||||
|
// start the signal server
|
||||||
|
_, signalAddr, err := startSignal(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start signal server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
counter := 0
|
||||||
|
// start the management server
|
||||||
|
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start management server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
||||||
|
defer cancel()
|
||||||
|
// create new server
|
||||||
|
ic := profilemanager.ConfigInput{
|
||||||
|
ManagementURL: "http://" + mgmtAddr,
|
||||||
|
ConfigPath: t.TempDir() + "/test-profile.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pm := profilemanager.ServiceManager{}
|
||||||
|
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||||
|
ID: "test-profile",
|
||||||
|
Username: currUser.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(ctx, "debug", "", false, false, false, false)
|
||||||
|
|
||||||
|
s.config = config
|
||||||
|
|
||||||
|
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
||||||
|
t.Setenv(retryInitialIntervalVar, "1s")
|
||||||
|
t.Setenv(maxRetryIntervalVar, "2s")
|
||||||
|
t.Setenv(maxRetryTimeVar, "5s")
|
||||||
|
t.Setenv(retryMultiplierVar, "1")
|
||||||
|
|
||||||
|
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||||
|
if counter < 3 {
|
||||||
|
t.Fatalf("expected counter > 2, got %d", counter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockServer struct {
|
||||||
|
mgmtProto.ManagementServiceServer
|
||||||
|
counter *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||||
|
*m.counter++
|
||||||
|
return m.ManagementServiceServer.Login(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
dataDir := t.TempDir()
|
||||||
|
|
||||||
|
config := &config.Config{
|
||||||
|
Stuns: []*config.Host{},
|
||||||
|
TURNConfig: &config.TURNConfig{},
|
||||||
|
Signal: &config.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: signalAddr,
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
|
||||||
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||||
|
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
|
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
|
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
mock := &mockServer{
|
||||||
|
ManagementServiceServer: mgmtServer,
|
||||||
|
counter: counter,
|
||||||
|
}
|
||||||
|
mgmtProto.RegisterManagementServiceServer(s, mock)
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
proto.RegisterSignalExchangeServer(s, srv)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err = s.Serve(lis); err != nil {
|
||||||
|
log.Fatalf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s, lis.Addr().String(), nil
|
||||||
|
}
|
||||||
@@ -2,124 +2,22 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
kaep = keepalive.EnforcementPolicy{
|
|
||||||
MinTime: 15 * time.Second,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
kasp = keepalive.ServerParameters{
|
|
||||||
MaxConnectionIdle: 15 * time.Second,
|
|
||||||
MaxConnectionAgeGrace: 5 * time.Second,
|
|
||||||
Time: 5 * time.Second,
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
|
|
||||||
// we will use a management server started via to simulate the server and capture the number of retries
|
|
||||||
func TestConnectWithRetryRuns(t *testing.T) {
|
|
||||||
// start the signal server
|
|
||||||
_, signalAddr, err := startSignal(t)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to start signal server: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
counter := 0
|
|
||||||
// start the management server
|
|
||||||
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to start management server: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
|
||||||
|
|
||||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
|
|
||||||
defer cancel()
|
|
||||||
// create new server
|
|
||||||
ic := profilemanager.ConfigInput{
|
|
||||||
ManagementURL: "http://" + mgmtAddr,
|
|
||||||
ConfigPath: t.TempDir() + "/test-profile.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := profilemanager.UpdateOrCreateConfig(ic)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
currUser, err := user.Current()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
pm := profilemanager.ServiceManager{}
|
|
||||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
|
||||||
ID: "test-profile",
|
|
||||||
Username: currUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to set active profile state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s := New(ctx, "debug", "", false, false, false, false)
|
|
||||||
|
|
||||||
s.config = config
|
|
||||||
|
|
||||||
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
|
|
||||||
t.Setenv(retryInitialIntervalVar, "1s")
|
|
||||||
t.Setenv(maxRetryIntervalVar, "2s")
|
|
||||||
t.Setenv(maxRetryTimeVar, "5s")
|
|
||||||
t.Setenv(retryMultiplierVar, "1")
|
|
||||||
|
|
||||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
|
||||||
if counter < 3 {
|
|
||||||
t.Fatalf("expected counter > 2, got %d", counter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_Up(t *testing.T) {
|
func TestServer_Up(t *testing.T) {
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||||
@@ -259,119 +157,3 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockServer struct {
|
|
||||||
mgmtProto.ManagementServiceServer
|
|
||||||
counter *int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
|
||||||
*m.counter++
|
|
||||||
return m.ManagementServiceServer.Login(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
dataDir := t.TempDir()
|
|
||||||
|
|
||||||
config := &config.Config{
|
|
||||||
Stuns: []*config.Host{},
|
|
||||||
TURNConfig: &config.TURNConfig{},
|
|
||||||
Signal: &config.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: signalAddr,
|
|
||||||
},
|
|
||||||
Datadir: dataDir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
|
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
t.Cleanup(ctrl.Finish)
|
|
||||||
|
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
|
||||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
|
||||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
|
||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
|
||||||
|
|
||||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
groupsManager := groups.NewManagerMock()
|
|
||||||
|
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
mock := &mockServer{
|
|
||||||
ManagementServiceServer: mgmtServer,
|
|
||||||
counter: counter,
|
|
||||||
}
|
|
||||||
mgmtProto.RegisterManagementServiceServer(s, mock)
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed to listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
|
|
||||||
require.NoError(t, err)
|
|
||||||
proto.RegisterSignalExchangeServer(s, srv)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err = s.Serve(lis); err != nil {
|
|
||||||
log.Fatalf("failed to serve: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, lis.Addr().String(), nil
|
|
||||||
}
|
|
||||||
|
|||||||
118
client/ssh/client/client_privileged_test.go
Normal file
118
client/ssh/client/client_privileged_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
|
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||||
|
}
|
||||||
|
|
||||||
|
server, _, client := setupTestSSHServerAndClient(t)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||||
|
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(output), "hello")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||||
|
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("commands with flags work", func(t *testing.T) {
|
||||||
|
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||||
|
var testCmd string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
testCmd = "echo hello | Select-String notfound"
|
||||||
|
} else {
|
||||||
|
testCmd = "echo 'hello' | grep 'notfound'"
|
||||||
|
}
|
||||||
|
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||||
|
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("connection with short timeout", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// Check for actual timeout-related errors rather than string matching
|
||||||
|
assert.True(t,
|
||||||
|
errors.Is(err, context.DeadlineExceeded) ||
|
||||||
|
errors.Is(err, context.Canceled) ||
|
||||||
|
strings.Contains(err.Error(), "timeout"),
|
||||||
|
"Expected timeout-related error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command execution cancellation", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
t.Logf("client close error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cmdCancel()
|
||||||
|
|
||||||
|
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||||
|
if err != nil {
|
||||||
|
var exitMissingErr *cryptossh.ExitMissingError
|
||||||
|
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||||
|
errors.Is(err, context.Canceled) ||
|
||||||
|
errors.As(err, &exitMissingErr)
|
||||||
|
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
@@ -78,53 +77,6 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
|||||||
assert.NotNil(t, client.client)
|
assert.NotNil(t, client.client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
|
||||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
|
||||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
|
||||||
}
|
|
||||||
|
|
||||||
server, _, client := setupTestSSHServerAndClient(t)
|
|
||||||
defer func() {
|
|
||||||
err := server.Stop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
defer func() {
|
|
||||||
err := client.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
|
||||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(output), "hello")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
|
||||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("commands with flags work", func(t *testing.T) {
|
|
||||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
|
||||||
var testCmd string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
testCmd = "echo hello | Select-String notfound"
|
|
||||||
} else {
|
|
||||||
testCmd = "echo 'hello' | grep 'notfound'"
|
|
||||||
}
|
|
||||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -154,59 +106,6 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
|
||||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
|
||||||
defer func() {
|
|
||||||
err := server.Stop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Run("connection with short timeout", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
currentUser := testutil.GetTestUsername(t)
|
|
||||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// Check for actual timeout-related errors rather than string matching
|
|
||||||
assert.True(t,
|
|
||||||
errors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
errors.Is(err, context.Canceled) ||
|
|
||||||
strings.Contains(err.Error(), "timeout"),
|
|
||||||
"Expected timeout-related error, got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("command execution cancellation", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
currentUser := testutil.GetTestUsername(t)
|
|
||||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
if err := client.Close(); err != nil {
|
|
||||||
t.Logf("client close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cmdCancel()
|
|
||||||
|
|
||||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
|
||||||
if err != nil {
|
|
||||||
var exitMissingErr *cryptossh.ExitMissingError
|
|
||||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
errors.Is(err, context.Canceled) ||
|
|
||||||
errors.As(err, &exitMissingErr)
|
|
||||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
423
client/ssh/proxy/proxy_privileged_test.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
//go:build privileged
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *mockDaemon) setJWTToken(token string) {
|
||||||
|
m.impl.jwtToken = token
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHProxy_Connect(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &server.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: &server.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audiences: []string{audience},
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer := server.New(serverConfig)
|
||||||
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
// Configure SSH authorization for the test user
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
|
defer func() { _ = sshServer.Stop() }()
|
||||||
|
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
defer mockDaemon.stop()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
origStdin := os.Stdin
|
||||||
|
origStdout := os.Stdout
|
||||||
|
defer func() {
|
||||||
|
os.Stdin = origStdin
|
||||||
|
os.Stdout = origStdout
|
||||||
|
}()
|
||||||
|
|
||||||
|
stdinReader, stdinWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
os.Stdin = stdinReader
|
||||||
|
os.Stdout = stdoutWriter
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
connectErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
connectErrCh <- proxyInstance.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sshConfig := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||||
|
require.NoError(t, err, "Should connect to proxy server")
|
||||||
|
defer func() { _ = sshClientConn.Close() }()
|
||||||
|
|
||||||
|
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||||
|
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||||
|
|
||||||
|
outputCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
output, err := session.Output("echo hello-from-proxy")
|
||||||
|
outputCh <- output
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case output := <-outputCh:
|
||||||
|
err := <-errCh
|
||||||
|
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||||
|
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("Command execution timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = session.Close()
|
||||||
|
_ = sshClient.Close()
|
||||||
|
_ = clientConn.Close()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||||
|
// when forwarding commands to the backend. This is critical for tools like
|
||||||
|
// Ansible that send commands such as:
|
||||||
|
//
|
||||||
|
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||||
|
//
|
||||||
|
// The single quotes must be preserved so the backend shell receives the
|
||||||
|
// subshell expression as a single argument to -c.
|
||||||
|
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClient, cleanup := setupProxySSHClient(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||||
|
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||||
|
// the local shell strips the outer single quotes, and the SSH exec request
|
||||||
|
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||||
|
//
|
||||||
|
// The proxy must forward this string verbatim. Using session.Command()
|
||||||
|
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||||
|
// the command on the backend.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "subshell_in_double_quotes",
|
||||||
|
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||||
|
expect: "from-subshell\nouter\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "printf_with_special_chars",
|
||||||
|
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||||
|
expect: "hello world\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested_command_substitution",
|
||||||
|
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||||
|
expect: "nested\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = session.Close() }()
|
||||||
|
|
||||||
|
var stderrBuf bytes.Buffer
|
||||||
|
session.Stderr = &stderrBuf
|
||||||
|
|
||||||
|
outputCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
output, err := session.Output(tc.command)
|
||||||
|
outputCh <- output
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case output := <-outputCh:
|
||||||
|
err := <-errCh
|
||||||
|
if stderrBuf.Len() > 0 {
|
||||||
|
t.Logf("stderr: %s", stderrBuf.String())
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||||
|
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatalf("command timed out: %s", tc.command)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupProxySSHClient creates a full proxy test environment and returns
|
||||||
|
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||||
|
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &server.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: &server.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audiences: []string{audience},
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer := server.New(serverConfig)
|
||||||
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
|
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
origStdin := os.Stdin
|
||||||
|
origStdout := os.Stdout
|
||||||
|
|
||||||
|
stdinReader, stdinWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
os.Stdin = stdinReader
|
||||||
|
os.Stdout = stdoutWriter
|
||||||
|
|
||||||
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
|
||||||
|
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||||
|
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = proxyInstance.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sshConfig := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||||
|
|
||||||
|
cleanupFn := func() {
|
||||||
|
_ = client.Close()
|
||||||
|
_ = clientConn.Close()
|
||||||
|
cancel()
|
||||||
|
os.Stdin = origStdin
|
||||||
|
os.Stdout = origStdout
|
||||||
|
_ = sshServer.Stop()
|
||||||
|
mockDaemon.stop()
|
||||||
|
jwksServer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, cleanupFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, jwksJSON := generateTestJWKS(t)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if _, err := w.Write(jwksJSON); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return server, privateKey, server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
n := publicKey.N.Bytes()
|
||||||
|
e := publicKey.E
|
||||||
|
|
||||||
|
jwk := nbjwt.JSONWebKey{
|
||||||
|
Kty: "RSA",
|
||||||
|
Kid: "test-key-id",
|
||||||
|
Use: "sig",
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(n),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := nbjwt.Jwks{
|
||||||
|
Keys: []nbjwt.JSONWebKey{jwk},
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksJSON, err := json.Marshal(jwks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return privateKey, jwksJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||||
|
t.Helper()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": user,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
@@ -1,25 +1,12 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
@@ -28,11 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/server"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
|
||||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -106,331 +89,6 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHProxy_Connect(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
issuer = "https://test-issuer.example.com"
|
|
||||||
audience = "test-audience"
|
|
||||||
)
|
|
||||||
|
|
||||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
||||||
defer jwksServer.Close()
|
|
||||||
|
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverConfig := &server.Config{
|
|
||||||
HostKeyPEM: hostKey,
|
|
||||||
JWT: &server.JWTConfig{
|
|
||||||
Issuer: issuer,
|
|
||||||
Audiences: []string{audience},
|
|
||||||
KeysLocation: jwksURL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer := server.New(serverConfig)
|
|
||||||
sshServer.SetAllowRootLogin(true)
|
|
||||||
|
|
||||||
// Configure SSH authorization for the test user
|
|
||||||
testUsername := testutil.GetTestUsername(t)
|
|
||||||
testJWTUser := "test-username"
|
|
||||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authConfig := &sshauth.Config{
|
|
||||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
||||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
||||||
MachineUsers: map[string][]uint32{
|
|
||||||
testUsername: {0}, // Index 0 in AuthorizedUsers
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer.UpdateSSHAuth(authConfig)
|
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
||||||
defer func() { _ = sshServer.Stop() }()
|
|
||||||
|
|
||||||
mockDaemon := startMockDaemon(t)
|
|
||||||
defer mockDaemon.stop()
|
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
||||||
mockDaemon.setJWTToken(validToken)
|
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
|
||||||
defer func() { _ = clientConn.Close() }()
|
|
||||||
|
|
||||||
origStdin := os.Stdin
|
|
||||||
origStdout := os.Stdout
|
|
||||||
defer func() {
|
|
||||||
os.Stdin = origStdin
|
|
||||||
os.Stdout = origStdout
|
|
||||||
}()
|
|
||||||
|
|
||||||
stdinReader, stdinWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
os.Stdin = stdinReader
|
|
||||||
os.Stdout = stdoutWriter
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
connectErrCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
connectErrCh <- proxyInstance.Connect(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
sshConfig := &cryptossh.ClientConfig{
|
|
||||||
User: testutil.GetTestUsername(t),
|
|
||||||
Auth: []cryptossh.AuthMethod{},
|
|
||||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
||||||
require.NoError(t, err, "Should connect to proxy server")
|
|
||||||
defer func() { _ = sshClientConn.Close() }()
|
|
||||||
|
|
||||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
||||||
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
|
||||||
|
|
||||||
outputCh := make(chan []byte, 1)
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
output, err := session.Output("echo hello-from-proxy")
|
|
||||||
outputCh <- output
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case output := <-outputCh:
|
|
||||||
err := <-errCh
|
|
||||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
|
||||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("Command execution timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = session.Close()
|
|
||||||
_ = sshClient.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
|
||||||
// when forwarding commands to the backend. This is critical for tools like
|
|
||||||
// Ansible that send commands such as:
|
|
||||||
//
|
|
||||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
|
||||||
//
|
|
||||||
// The single quotes must be preserved so the backend shell receives the
|
|
||||||
// subshell expression as a single argument to -c.
|
|
||||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClient, cleanup := setupProxySSHClient(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
|
||||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
|
||||||
// the local shell strips the outer single quotes, and the SSH exec request
|
|
||||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
|
||||||
//
|
|
||||||
// The proxy must forward this string verbatim. Using session.Command()
|
|
||||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
|
||||||
// the command on the backend.
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
command string
|
|
||||||
expect string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "subshell_in_double_quotes",
|
|
||||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
|
||||||
expect: "from-subshell\nouter\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "printf_with_special_chars",
|
|
||||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
|
||||||
expect: "hello world\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nested_command_substitution",
|
|
||||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
|
||||||
expect: "nested\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = session.Close() }()
|
|
||||||
|
|
||||||
var stderrBuf bytes.Buffer
|
|
||||||
session.Stderr = &stderrBuf
|
|
||||||
|
|
||||||
outputCh := make(chan []byte, 1)
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
output, err := session.Output(tc.command)
|
|
||||||
outputCh <- output
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case output := <-outputCh:
|
|
||||||
err := <-errCh
|
|
||||||
if stderrBuf.Len() > 0 {
|
|
||||||
t.Logf("stderr: %s", stderrBuf.String())
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
|
||||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatalf("command timed out: %s", tc.command)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupProxySSHClient creates a full proxy test environment and returns
|
|
||||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
|
||||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
const (
|
|
||||||
issuer = "https://test-issuer.example.com"
|
|
||||||
audience = "test-audience"
|
|
||||||
)
|
|
||||||
|
|
||||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
||||||
|
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverConfig := &server.Config{
|
|
||||||
HostKeyPEM: hostKey,
|
|
||||||
JWT: &server.JWTConfig{
|
|
||||||
Issuer: issuer,
|
|
||||||
Audiences: []string{audience},
|
|
||||||
KeysLocation: jwksURL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer := server.New(serverConfig)
|
|
||||||
sshServer.SetAllowRootLogin(true)
|
|
||||||
|
|
||||||
testUsername := testutil.GetTestUsername(t)
|
|
||||||
testJWTUser := "test-username"
|
|
||||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authConfig := &sshauth.Config{
|
|
||||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
||||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
||||||
MachineUsers: map[string][]uint32{
|
|
||||||
testUsername: {0},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
sshServer.UpdateSSHAuth(authConfig)
|
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
||||||
|
|
||||||
mockDaemon := startMockDaemon(t)
|
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
||||||
mockDaemon.setJWTToken(validToken)
|
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
origStdin := os.Stdin
|
|
||||||
origStdout := os.Stdout
|
|
||||||
|
|
||||||
stdinReader, stdinWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
os.Stdin = stdinReader
|
|
||||||
os.Stdout = stdoutWriter
|
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
|
||||||
|
|
||||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
|
||||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = proxyInstance.Connect(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
sshConfig := &cryptossh.ClientConfig{
|
|
||||||
User: testutil.GetTestUsername(t),
|
|
||||||
Auth: []cryptossh.AuthMethod{},
|
|
||||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
||||||
|
|
||||||
cleanupFn := func() {
|
|
||||||
_ = client.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
cancel()
|
|
||||||
os.Stdin = origStdin
|
|
||||||
os.Stdout = origStdout
|
|
||||||
_ = sshServer.Stop()
|
|
||||||
mockDaemon.stop()
|
|
||||||
jwksServer.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, cleanupFn
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockDaemonServer struct {
|
type mockDaemonServer struct {
|
||||||
proto.UnimplementedDaemonServiceServer
|
proto.UnimplementedDaemonServiceServer
|
||||||
hostKeys map[string][]byte
|
hostKeys map[string][]byte
|
||||||
@@ -492,10 +150,6 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
|||||||
m.impl.hostKeys[addr] = pubKey
|
m.impl.hostKeys[addr] = pubKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockDaemon) setJWTToken(token string) {
|
|
||||||
m.impl.jwtToken = token
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDaemon) stop() {
|
func (m *mockDaemon) stop() {
|
||||||
if m.server != nil {
|
if m.server != nil {
|
||||||
m.server.Stop()
|
m.server.Stop()
|
||||||
@@ -508,63 +162,3 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return pubKey
|
return pubKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
|
||||||
t.Helper()
|
|
||||||
privateKey, jwksJSON := generateTestJWKS(t)
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if _, err := w.Write(jwksJSON); err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
|
|
||||||
return server, privateKey, server.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
|
||||||
t.Helper()
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
publicKey := &privateKey.PublicKey
|
|
||||||
n := publicKey.N.Bytes()
|
|
||||||
e := publicKey.E
|
|
||||||
|
|
||||||
jwk := nbjwt.JSONWebKey{
|
|
||||||
Kty: "RSA",
|
|
||||||
Kid: "test-key-id",
|
|
||||||
Use: "sig",
|
|
||||||
N: base64.RawURLEncoding.EncodeToString(n),
|
|
||||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
|
||||||
}
|
|
||||||
|
|
||||||
jwks := nbjwt.Jwks{
|
|
||||||
Keys: []nbjwt.JSONWebKey{jwk},
|
|
||||||
}
|
|
||||||
|
|
||||||
jwksJSON, err := json.Marshal(jwks)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return privateKey, jwksJSON
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
|
||||||
t.Helper()
|
|
||||||
claims := jwt.MapClaims{
|
|
||||||
"iss": issuer,
|
|
||||||
"aud": audience,
|
|
||||||
"sub": user,
|
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
|
||||||
"iat": time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
||||||
token.Header["kid"] = "test-key-id"
|
|
||||||
|
|
||||||
tokenString, err := token.SignedString(privateKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return tokenString
|
|
||||||
}
|
|
||||||
|
|||||||
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
66
client/ssh/server/executor_unix_privileged_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build unix && privileged
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||||
|
pd := NewPrivilegeDropper()
|
||||||
|
|
||||||
|
config := ExecutorConfig{
|
||||||
|
UID: 1000,
|
||||||
|
GID: 1000,
|
||||||
|
Groups: []uint32{1000, 1001},
|
||||||
|
WorkingDir: "/home/testuser",
|
||||||
|
Shell: "/bin/bash",
|
||||||
|
Command: "ls -la",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, cmd)
|
||||||
|
|
||||||
|
// Verify the command is calling netbird ssh exec
|
||||||
|
assert.Contains(t, cmd.Args, "ssh")
|
||||||
|
assert.Contains(t, cmd.Args, "exec")
|
||||||
|
assert.Contains(t, cmd.Args, "--uid")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "--gid")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "--groups")
|
||||||
|
assert.Contains(t, cmd.Args, "1000")
|
||||||
|
assert.Contains(t, cmd.Args, "1001")
|
||||||
|
assert.Contains(t, cmd.Args, "--working-dir")
|
||||||
|
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||||
|
assert.Contains(t, cmd.Args, "--shell")
|
||||||
|
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||||
|
assert.Contains(t, cmd.Args, "--cmd")
|
||||||
|
assert.Contains(t, cmd.Args, "ls -la")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||||
|
pd := NewPrivilegeDropper()
|
||||||
|
|
||||||
|
config := ExecutorConfig{
|
||||||
|
UID: 1000,
|
||||||
|
GID: 1000,
|
||||||
|
Groups: []uint32{1000},
|
||||||
|
WorkingDir: "/home/testuser",
|
||||||
|
Shell: "/bin/bash",
|
||||||
|
Command: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, cmd)
|
||||||
|
|
||||||
|
// Verify no command mode (command is empty so no --cmd flag)
|
||||||
|
assert.NotContains(t, cmd.Args, "--cmd")
|
||||||
|
assert.NotContains(t, cmd.Args, "--interactive")
|
||||||
|
}
|
||||||
@@ -73,61 +73,6 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
|
||||||
pd := NewPrivilegeDropper()
|
|
||||||
|
|
||||||
config := ExecutorConfig{
|
|
||||||
UID: 1000,
|
|
||||||
GID: 1000,
|
|
||||||
Groups: []uint32{1000, 1001},
|
|
||||||
WorkingDir: "/home/testuser",
|
|
||||||
Shell: "/bin/bash",
|
|
||||||
Command: "ls -la",
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, cmd)
|
|
||||||
|
|
||||||
// Verify the command is calling netbird ssh exec
|
|
||||||
assert.Contains(t, cmd.Args, "ssh")
|
|
||||||
assert.Contains(t, cmd.Args, "exec")
|
|
||||||
assert.Contains(t, cmd.Args, "--uid")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "--gid")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "--groups")
|
|
||||||
assert.Contains(t, cmd.Args, "1000")
|
|
||||||
assert.Contains(t, cmd.Args, "1001")
|
|
||||||
assert.Contains(t, cmd.Args, "--working-dir")
|
|
||||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
|
||||||
assert.Contains(t, cmd.Args, "--shell")
|
|
||||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
|
||||||
assert.Contains(t, cmd.Args, "--cmd")
|
|
||||||
assert.Contains(t, cmd.Args, "ls -la")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
|
||||||
pd := NewPrivilegeDropper()
|
|
||||||
|
|
||||||
config := ExecutorConfig{
|
|
||||||
UID: 1000,
|
|
||||||
GID: 1000,
|
|
||||||
Groups: []uint32{1000},
|
|
||||||
WorkingDir: "/home/testuser",
|
|
||||||
Shell: "/bin/bash",
|
|
||||||
Command: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, cmd)
|
|
||||||
|
|
||||||
// Verify no command mode (command is empty so no --cmd flag)
|
|
||||||
assert.NotContains(t, cmd.Args, "--cmd")
|
|
||||||
assert.NotContains(t, cmd.Args, "--interactive")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||||
// This test requires root privileges and will be skipped if not running as root
|
// This test requires root privileges and will be skipped if not running as root
|
||||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
@@ -121,6 +124,23 @@ func (i *Info) SetFlags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeAddresses drops network addresses whose IP matches any of the given
|
||||||
|
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
|
||||||
|
// address, which otherwise churns the meta as the interface comes and goes.
|
||||||
|
func (i *Info) removeAddresses(ips ...netip.Addr) {
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filtered := i.NetworkAddresses[:0]
|
||||||
|
for _, addr := range i.NetworkAddresses {
|
||||||
|
if slices.Contains(ips, addr.NetIP.Addr()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, addr)
|
||||||
|
}
|
||||||
|
i.NetworkAddresses = filtered
|
||||||
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
@@ -147,14 +167,16 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
// excludeIPs are dropped from the reported network addresses (e.g. our own
|
||||||
|
// WireGuard overlay address, which otherwise churns the peer meta).
|
||||||
|
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
|
||||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||||
processCheckPaths := make([]string, 0)
|
processCheckPaths := make([]string, 0)
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
files, err := checkFileAndProcess(processCheckPaths)
|
files, err := checkFileAndProcess(ctx, processCheckPaths)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -162,7 +184,48 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
|||||||
|
|
||||||
info := GetInfo(ctx)
|
info := GetInfo(ctx)
|
||||||
info.Files = files
|
info.Files = files
|
||||||
|
info.removeAddresses(excludeIPs...)
|
||||||
|
|
||||||
log.Debugf("all system information gathered successfully")
|
log.Debugf("all system information gathered successfully")
|
||||||
return info, nil
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
|
||||||
|
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
|
||||||
|
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
|
||||||
|
// does not return within timeout the caller gets (nil, false) and should proceed with
|
||||||
|
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
|
||||||
|
//
|
||||||
|
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
|
||||||
|
// returns, so it does not leak beyond the duration of that call.
|
||||||
|
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
infoCh := make(chan *Info, 1)
|
||||||
|
go func() {
|
||||||
|
info, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
|
info = GetInfo(ctx)
|
||||||
|
info.removeAddresses(excludeIPs...)
|
||||||
|
}
|
||||||
|
infoCh <- info
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case info := <-infoCh:
|
||||||
|
return info, true
|
||||||
|
case <-ctx.Done():
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
log.Warnf("gathering system info with checks timed out after %s", timeout)
|
||||||
|
} else {
|
||||||
|
// Parent context canceled (e.g. shutdown), not a timeout.
|
||||||
|
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||||
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||||
swVersion = []byte(release)
|
swVersion = []byte(release)
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
@@ -34,6 +36,20 @@ func Test_CustomHostname(t *testing.T) {
|
|||||||
assert.Equal(t, want, got.Hostname)
|
assert.Equal(t, want, got.Hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
|
||||||
|
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
|
||||||
|
assert.True(t, ok, "expected gathering to complete within the timeout")
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
|
||||||
|
// A 1ns budget expires before the (real) system-info gathering can finish, so the
|
||||||
|
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
|
||||||
|
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
|
||||||
|
assert.False(t, ok, "expected timeout to be reported")
|
||||||
|
assert.Nil(t, info)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_NetAddresses(t *testing.T) {
|
func Test_NetAddresses(t *testing.T) {
|
||||||
addr, err := networkAddresses()
|
addr, err := networkAddresses()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -43,3 +59,42 @@ func Test_NetAddresses(t *testing.T) {
|
|||||||
t.Errorf("no network addresses found")
|
t.Errorf("no network addresses found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInfo_RemoveAddresses(t *testing.T) {
|
||||||
|
addr := func(cidr string) NetworkAddress {
|
||||||
|
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
|
||||||
|
}
|
||||||
|
|
||||||
|
info := &Info{
|
||||||
|
NetworkAddresses: []NetworkAddress{
|
||||||
|
addr("192.168.1.7/24"),
|
||||||
|
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
|
||||||
|
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
|
||||||
|
addr("fd00:1234::1/64"), // overlay v6
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
|
||||||
|
info.removeAddresses(
|
||||||
|
netip.MustParseAddr("100.76.70.97"),
|
||||||
|
netip.MustParseAddr("fd00:1234::1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
|
||||||
|
if len(info.NetworkAddresses) != len(want) {
|
||||||
|
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
|
||||||
|
}
|
||||||
|
for i, w := range want {
|
||||||
|
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
|
||||||
|
t.Errorf("address[%d] = %s, want %s", i, got, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
|
||||||
|
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
|
||||||
|
info.removeAddresses()
|
||||||
|
if len(info.NetworkAddresses) != 1 {
|
||||||
|
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return NetworkAddress{}, false
|
return NetworkAddress{}, false
|
||||||
}
|
}
|
||||||
if ipNet.IP.IsLoopback() {
|
// Skip link-local and multicast: they carry no routable peer info and the
|
||||||
|
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
|
||||||
|
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||||
return NetworkAddress{}, false
|
return NetworkAddress{}, false
|
||||||
}
|
}
|
||||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||||
|
|||||||
45
client/system/network_addr_test.go
Normal file
45
client/system/network_addr_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
|
||||||
|
t.Helper()
|
||||||
|
ip, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse %q: %v", cidr, err)
|
||||||
|
}
|
||||||
|
ipNet.IP = ip
|
||||||
|
return ipNet
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToNetworkAddress_Filtering(t *testing.T) {
|
||||||
|
const mac = "c8:4b:d6:b6:04:ac"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cidr string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"ipv4 global", "10.65.16.181/23", true},
|
||||||
|
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
|
||||||
|
{"ipv4 loopback", "127.0.0.1/8", false},
|
||||||
|
{"ipv6 loopback", "::1/128", false},
|
||||||
|
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
|
||||||
|
{"ipv4 link-local", "169.254.1.2/16", false},
|
||||||
|
{"ipv6 multicast", "ff02::1/128", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,24 +3,30 @@
|
|||||||
package system
|
package system
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/shirou/gopsutil/v3/process"
|
"github.com/shirou/gopsutil/v3/process"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getRunningProcesses returns a list of running process paths.
|
// getRunningProcesses returns a list of running process paths. The context bounds the work:
|
||||||
func getRunningProcesses() ([]string, error) {
|
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
|
||||||
processIDs, err := process.Pids()
|
// can, so a stuck enumeration cannot run unbounded.
|
||||||
|
func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||||
|
processIDs, err := process.PidsWithContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
processMap := make(map[string]bool)
|
processMap := make(map[string]bool)
|
||||||
for _, pID := range processIDs {
|
for _, pID := range processIDs {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
p := &process.Process{Pid: pID}
|
p := &process.Process{Pid: pID}
|
||||||
|
|
||||||
path, _ := p.Exe()
|
path, _ := p.ExeWithContext(ctx)
|
||||||
if path != "" {
|
if path != "" {
|
||||||
processMap[path] = false
|
processMap[path] = false
|
||||||
}
|
}
|
||||||
@@ -35,18 +41,21 @@ func getRunningProcesses() ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
|
||||||
files := make([]File, len(paths))
|
files := make([]File, len(paths))
|
||||||
if len(paths) == 0 {
|
if len(paths) == 0 {
|
||||||
return files, nil
|
return files, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
runningProcesses, err := getRunningProcesses()
|
runningProcesses, err := getRunningProcesses(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, path := range paths {
|
for i, path := range paths {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
file := File{Path: path}
|
file := File{Path: path}
|
||||||
|
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user