Compare commits

..

4 Commits

Author SHA1 Message Date
bcmmbaga
4edb944051 Merge branch 'main' into refactor/mgmt-bootstrap 2026-06-16 17:17:03 +03:00
bcmmbaga
37ce40ede6 build client config in sync response 2026-06-05 23:53:48 +03:00
bcmmbaga
318014552f Merge branch 'main' into refactor/mgmt-bootstrap 2026-06-05 22:02:25 +03:00
bcmmbaga
78ed62535e Clean up middleware and move auth middlewa into module 2026-05-26 20:45:57 +03:00
242 changed files with 5401 additions and 13247 deletions

View File

@@ -20,7 +20,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
@@ -59,12 +59,12 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Set up Go - name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: true cache: true

View File

@@ -15,7 +15,7 @@ jobs:
pull-requests: write pull-requests: write
steps: steps:
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1 - uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1

View File

@@ -16,18 +16,18 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }} key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -45,10 +45,10 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
@@ -28,7 +28,7 @@ jobs:
id: test id: test
env: env:
GO_VERSION: ${{ steps.goversion.outputs.version }} GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8 uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with: with:
usesh: true usesh: true
copyback: false copyback: false
@@ -48,14 +48,14 @@ jobs:
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -tags privileged -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use` # NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/... time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -tags privileged -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -tags privileged -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...
time go test -tags privileged -timeout 1m -failfast ./formatter/... time go test -timeout 1m -failfast ./formatter/...
time go test -tags privileged -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./client/iface/...
time go test -tags privileged -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./route/...
time go test -tags privileged -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./sharedsock/...
time go test -tags privileged -timeout 1m -failfast ./util/... time go test -timeout 1m -failfast ./util/...
time go test -tags privileged -timeout 1m -failfast ./version/... time go test -timeout 1m -failfast ./version/...

View File

@@ -18,7 +18,7 @@ jobs:
management: ${{ steps.filter.outputs.management }} management: ${{ steps.filter.outputs.management }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
@@ -30,7 +30,7 @@ jobs:
- 'management/**' - 'management/**'
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -41,7 +41,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache id: cache
with: with:
path: | path: |
@@ -119,12 +119,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -135,7 +135,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -158,11 +158,11 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird
@@ -175,12 +175,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -192,7 +192,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache-restore id: cache-restore
with: with:
path: | path: |
@@ -229,7 +229,7 @@ jobs:
sh -c ' \ sh -c ' \
apk update; apk add --no-cache \ apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged) go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
' '
test_relay: test_relay:
@@ -246,12 +246,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -266,7 +266,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -290,7 +290,7 @@ jobs:
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird
@@ -306,12 +306,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -325,7 +325,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -347,7 +347,7 @@ jobs:
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird
@@ -363,12 +363,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -383,7 +383,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -407,7 +407,7 @@ jobs:
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird
@@ -424,12 +424,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -440,7 +440,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -484,7 +484,7 @@ jobs:
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird
@@ -529,12 +529,12 @@ jobs:
prom/prometheus prom/prometheus
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -545,7 +545,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -579,11 +579,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
env:
GIT_BRANCH: ${{ github.ref_name }}
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -624,12 +623,12 @@ jobs:
prom/prometheus prom/prometheus
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -640,7 +639,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -674,13 +673,12 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \ go test -tags=benchmark \
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/... -timeout 20m ./management/server/http/...
env:
GIT_BRANCH: ${{ github.ref_name }}
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -694,12 +692,12 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -710,7 +708,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -736,7 +734,7 @@ jobs:
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64' if: matrix.arch == 'amd64'
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird slug: netbirdio/netbird

View File

@@ -18,12 +18,12 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
id: go id: go
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
@@ -35,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
${{ env.cache }} ${{ env.cache }}
@@ -68,7 +68,7 @@ jobs:
run: | run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } $packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe" $goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1" $cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test - name: test

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: codespell - name: codespell
@@ -40,7 +40,7 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Check for duplicate constants - name: Check for duplicate constants
@@ -48,7 +48,7 @@ jobs:
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false

View File

@@ -22,7 +22,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false

View File

@@ -16,11 +16,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Setup Android SDK - name: Setup Android SDK
@@ -28,13 +28,13 @@ jobs:
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520 uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -54,11 +54,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: install gomobile - name: install gomobile

View File

@@ -9,13 +9,10 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.6" SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.16.0" GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
flags: ""
SKIP_PUBLISH: "true"
SKIP_DOCKER_PUSH: "false"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -27,7 +24,7 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
@@ -64,7 +61,7 @@ jobs:
if: steps.check_diff.outputs.diff_exists == 'true' if: steps.check_diff.outputs.diff_exists == 'true'
env: env:
GO_VERSION: ${{ steps.goversion.outputs.version }} GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8 uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with: with:
usesh: true usesh: true
copyback: false copyback: false
@@ -133,9 +130,11 @@ jobs:
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }} windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }} macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }} ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env:
flags: ""
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
@@ -144,34 +143,15 @@ jobs:
id: semver_parser id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2 uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Set snapshot flag - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
fi
- name: Set up Go - name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -186,9 +166,9 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0 uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0 uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Login to Docker hub - name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
@@ -221,7 +201,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2 uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }} args: release --clean ${{ env.flags }}
@@ -232,8 +212,6 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
- name: Verify RPM signatures - name: Verify RPM signatures
run: | run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
@@ -347,7 +325,7 @@ jobs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }} release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
@@ -356,30 +334,16 @@ jobs:
id: semver_parser id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2 uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Set snapshot flag - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
- name: Set up Go - name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -420,7 +384,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2 uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -431,7 +395,6 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }} GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }} NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
- name: Verify RPM signatures - name: Verify RPM signatures
run: | run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c ' docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
@@ -464,17 +427,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
- name: Set up Go - name: Set up Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -488,7 +451,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2 uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -522,7 +485,7 @@ jobs:
downloadPath: '${{ github.workspace }}\temp' downloadPath: '${{ github.workspace }}\temp'
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
@@ -534,13 +497,13 @@ jobs:
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts - name: Download release artifacts
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with: with:
name: release name: release
path: release path: release
- name: Download UI release artifacts - name: Download UI release artifacts
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with: with:
name: release-ui name: release-ui
path: release-ui path: release-ui

View File

@@ -68,17 +68,17 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0 uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -207,7 +207,7 @@ jobs:
- name: Build management docker image - name: Build management docker image
working-directory: management working-directory: management
run: | run: |
docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. . docker build -t netbirdio/management:latest .
- name: Build signal binary - name: Build signal binary
working-directory: signal working-directory: signal
@@ -216,7 +216,7 @@ jobs:
- name: Build signal docker image - name: Build signal docker image
working-directory: signal working-directory: signal
run: | run: |
docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. . docker build -t netbirdio/signal:latest .
- name: Build relay binary - name: Build relay binary
working-directory: relay working-directory: relay
@@ -225,7 +225,7 @@ jobs:
- name: Build relay docker image - name: Build relay docker image
working-directory: relay working-directory: relay
run: | run: |
docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. . docker build -t netbirdio/relay:latest .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
@@ -256,7 +256,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false

View File

@@ -19,11 +19,11 @@ jobs:
GOARCH: wasm GOARCH: wasm
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Install dependencies - name: Install dependencies
@@ -44,11 +44,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0 uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Build Wasm client - name: Build Wasm client

View File

@@ -1,7 +1,5 @@
version: 2 version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
project_name: netbird project_name: netbird
builds: builds:
- id: netbird-wasm - id: netbird-wasm
@@ -76,8 +74,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -92,8 +88,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -108,8 +102,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -130,8 +122,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -146,8 +136,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -162,8 +150,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}} - -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -184,8 +170,6 @@ builds:
- amd64 - amd64
- arm64 - arm64
- arm - arm
goarm:
- 7
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
@@ -238,192 +222,670 @@ nfpms:
rpm: rpm:
signature: signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers_v2: dockers:
- id: netbird - image_templates:
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - netbirdio/netbird:{{ .Version }}-amd64
ids: - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- netbird ids:
images: - netbird
- netbirdio/netbird goarch: amd64
- ghcr.io/netbirdio/netbird use: buildx
tags: dockerfile: client/Dockerfile
- "{{ .Version }}" extra_files:
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - client/netbird-entrypoint.sh
dockerfile: client/Dockerfile build_flag_templates:
extra_files: - "--platform=linux/amd64"
- client/netbird-entrypoint.sh - "--label=org.opencontainers.image.created={{.Date}}"
platforms: - "--label=org.opencontainers.image.title={{.ProjectName}}"
- linux/amd64 - "--label=org.opencontainers.image.version={{.Version}}"
- linux/arm64 - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- linux/arm/6 - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
annotations: - "--label=maintainer=dev@netbird.io"
"org.opencontainers.image.created": "{{.Date}}" - image_templates:
"org.opencontainers.image.title": "{{.ProjectName}}" - netbirdio/netbird:{{ .Version }}-arm64v8
"org.opencontainers.image.version": "{{.Version}}" - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
"org.opencontainers.image.revision": "{{.FullCommit}}" ids:
"org.opencontainers.image.source": "{{.GitURL}}" - netbird
"maintainer": "dev@netbird.io" goarch: arm64
- id: netbird-rootless use: buildx
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" dockerfile: client/Dockerfile
ids: extra_files:
- netbird - client/netbird-entrypoint.sh
images: build_flag_templates:
- netbirdio/netbird - "--platform=linux/arm64"
- ghcr.io/netbirdio/netbird - "--label=org.opencontainers.image.created={{.Date}}"
tags: - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "v{{ .Version }}-rootless" - "--label=org.opencontainers.image.version={{.Version}}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
dockerfile: client/Dockerfile-rootless - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
extra_files: - "--label=maintainer=dev@netbird.io"
- client/netbird-entrypoint.sh - image_templates:
platforms: - netbirdio/netbird:{{ .Version }}-arm
- linux/amd64 - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- linux/arm64 ids:
- linux/arm/6 - netbird
annotations: goarch: arm
"org.opencontainers.image.created": "{{.Date}}" goarm: 6
"org.opencontainers.image.title": "{{.ProjectName}}" use: buildx
"org.opencontainers.image.version": "{{.Version}}" dockerfile: client/Dockerfile
"org.opencontainers.image.revision": "{{.FullCommit}}" extra_files:
"org.opencontainers.image.source": "{{.GitURL}}" - client/netbird-entrypoint.sh
"maintainer": "dev@netbird.io" build_flag_templates:
- id: relay - "--platform=linux/arm"
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - "--label=org.opencontainers.image.created={{.Date}}"
ids: - "--label=org.opencontainers.image.title={{.ProjectName}}"
- netbird-relay - "--label=org.opencontainers.image.version={{.Version}}"
images: - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- netbirdio/relay - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- ghcr.io/netbirdio/relay - "--label=maintainer=dev@netbird.io"
tags:
- "{{ .Version }}" - image_templates:
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - netbirdio/netbird:{{ .Version }}-rootless-amd64
dockerfile: relay/Dockerfile - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
platforms: ids:
- linux/amd64 - netbird
- linux/arm64 goarch: amd64
- linux/arm use: buildx
annotations: dockerfile: client/Dockerfile-rootless
"org.opencontainers.image.created": "{{.Date}}" extra_files:
"org.opencontainers.image.title": "{{.ProjectName}}" - client/netbird-entrypoint.sh
"org.opencontainers.image.version": "{{.Version}}" build_flag_templates:
"org.opencontainers.image.revision": "{{.FullCommit}}" - "--platform=linux/amd64"
"org.opencontainers.image.source": "{{.GitURL}}" - "--label=org.opencontainers.image.created={{.Date}}"
"maintainer": "dev@netbird.io" - "--label=org.opencontainers.image.title={{.ProjectName}}"
- id: signal - "--label=org.opencontainers.image.version={{.Version}}"
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
ids: - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbird-signal - "--label=maintainer=dev@netbird.io"
images: - image_templates:
- netbirdio/signal - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/signal - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
tags: ids:
- "{{ .Version }}" - netbird
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" goarch: arm64
dockerfile: signal/Dockerfile use: buildx
platforms: dockerfile: client/Dockerfile-rootless
- linux/amd64 extra_files:
- linux/arm64 - client/netbird-entrypoint.sh
- linux/arm build_flag_templates:
annotations: - "--platform=linux/arm64"
"org.opencontainers.image.created": "{{.Date}}" - "--label=org.opencontainers.image.created={{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
"maintainer": "dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- id: management - image_templates:
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - netbirdio/netbird:{{ .Version }}-rootless-arm
ids: - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- netbird-mgmt ids:
images: - netbird
- netbirdio/management goarch: arm
- ghcr.io/netbirdio/management goarm: 6
tags: use: buildx
- "{{ .Version }}" dockerfile: client/Dockerfile-rootless
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" extra_files:
dockerfile: management/Dockerfile - client/netbird-entrypoint.sh
platforms: build_flag_templates:
- linux/amd64 - "--platform=linux/arm"
- linux/arm64 - "--label=org.opencontainers.image.created={{.Date}}"
- linux/arm - "--label=org.opencontainers.image.title={{.ProjectName}}"
annotations: - "--label=org.opencontainers.image.version={{.Version}}"
"org.opencontainers.image.created": "{{.Date}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
"org.opencontainers.image.title": "{{.ProjectName}}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}" - "--label=maintainer=dev@netbird.io"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}" - image_templates:
"maintainer": "dev@netbird.io" - netbirdio/relay:{{ .Version }}-amd64
- id: upload - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" ids:
ids: - netbird-relay
- netbird-upload goarch: amd64
images: use: buildx
- netbirdio/upload dockerfile: relay/Dockerfile
- ghcr.io/netbirdio/upload build_flag_templates:
tags: - "--platform=linux/amd64"
- "{{ .Version }}" - "--label=org.opencontainers.image.created={{.Date}}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
dockerfile: upload-server/Dockerfile - "--label=org.opencontainers.image.version={{.Version}}"
platforms: - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- linux/amd64 - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- linux/arm64 - "--label=maintainer=dev@netbird.io"
- linux/arm - image_templates:
annotations: - netbirdio/relay:{{ .Version }}-arm64v8
"org.opencontainers.image.created": "{{.Date}}" - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
"org.opencontainers.image.title": "{{.ProjectName}}" ids:
"org.opencontainers.image.version": "{{.Version}}" - netbird-relay
"org.opencontainers.image.revision": "{{.FullCommit}}" goarch: arm64
"org.opencontainers.image.source": "{{.GitURL}}" use: buildx
"maintainer": "dev@netbird.io" dockerfile: relay/Dockerfile
- id: netbird-server build_flag_templates:
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - "--platform=linux/arm64"
ids: - "--label=org.opencontainers.image.created={{.Date}}"
- netbird-server - "--label=org.opencontainers.image.title={{.ProjectName}}"
images: - "--label=org.opencontainers.image.version={{.Version}}"
- netbirdio/netbird-server - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- ghcr.io/netbirdio/netbird-server - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
tags: - "--label=maintainer=dev@netbird.io"
- "{{ .Version }}" - image_templates:
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - netbirdio/relay:{{ .Version }}-arm
dockerfile: combined/Dockerfile - ghcr.io/netbirdio/relay:{{ .Version }}-arm
platforms: ids:
- linux/amd64 - netbird-relay
- linux/arm64 goarch: arm
- linux/arm goarm: 6
annotations: use: buildx
"org.opencontainers.image.created": "{{.Date}}" dockerfile: relay/Dockerfile
"org.opencontainers.image.title": "{{.ProjectName}}" build_flag_templates:
"org.opencontainers.image.version": "{{.Version}}" - "--platform=linux/arm"
"org.opencontainers.image.revision": "{{.FullCommit}}" - "--label=org.opencontainers.image.created={{.Date}}"
"org.opencontainers.image.source": "{{.GitURL}}" - "--label=org.opencontainers.image.title={{.ProjectName}}"
"maintainer": "dev@netbird.io" - "--label=org.opencontainers.image.version={{.Version}}"
- id: netbird-proxy - "--label=org.opencontainers.image.revision={{.FullCommit}}"
disable: "{{ .Env.SKIP_DOCKER_PUSH }}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
ids: - "--label=maintainer=dev@netbird.io"
- netbird-proxy - image_templates:
images: - netbirdio/signal:{{ .Version }}-amd64
- netbirdio/reverse-proxy - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy ids:
tags: - netbird-signal
- "{{ .Version }}" goarch: amd64
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" use: buildx
dockerfile: proxy/Dockerfile dockerfile: signal/Dockerfile
platforms: build_flag_templates:
- linux/amd64 - "--platform=linux/amd64"
- linux/arm64 - "--label=org.opencontainers.image.created={{.Date}}"
- linux/arm - "--label=org.opencontainers.image.title={{.ProjectName}}"
annotations: - "--label=org.opencontainers.image.version={{.Version}}"
"org.opencontainers.image.created": "{{.Date}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
"org.opencontainers.image.title": "{{.ProjectName}}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}" - "--label=maintainer=dev@netbird.io"
"org.opencontainers.image.revision": "{{.FullCommit}}" - image_templates:
"org.opencontainers.image.source": "{{.GitURL}}" - netbirdio/signal:{{ .Version }}-arm64v8
"maintainer": "dev@netbird.io" - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
goarm: 6
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:latest
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
repository: repository:
owner: netbirdio owner: netbirdio
name: homebrew-tap name: homebrew-tap
@@ -440,7 +902,6 @@ brews:
uploads: uploads:
- name: debian - name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids: ids:
- netbird_deb - netbird_deb
mode: archive mode: archive
@@ -449,7 +910,6 @@ uploads:
method: PUT method: PUT
- name: yum - name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids: ids:
- netbird_rpm - netbird_rpm
mode: archive mode: archive
@@ -462,13 +922,9 @@ checksum:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh - glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh
release: release:
extra_files: extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh - glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh

View File

@@ -1,6 +1,5 @@
version: 2 version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui - id: netbird-ui
@@ -102,7 +101,6 @@ nfpms:
uploads: uploads:
- name: debian - name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids: ids:
- netbird_ui_deb - netbird_ui_deb
mode: archive mode: archive
@@ -111,7 +109,6 @@ uploads:
method: PUT method: PUT
- name: yum - name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids: ids:
- netbird_ui_rpm - netbird_ui_rpm
mode: archive mode: archive

View File

@@ -1,4 +1,4 @@
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged .PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed # Install golangci-lint locally if needed
@@ -25,15 +25,3 @@ setup-hooks:
@git config core.hooksPath .githooks @git config core.hooksPath .githooks
@chmod +x .githooks/pre-push @chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'" @echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
# Runs as a normal user with no sudo and leaves host networking untouched.
test-unit:
@go test -tags devcert -timeout 10m ./...
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
# Narrow the run with env vars, e.g.:
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
test-privileged:
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...

View File

@@ -33,15 +33,10 @@
<br/> <br/>
<br/> <br/>
<strong> <strong>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a> 🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong> </strong>
</p> </p>
> ### 🤖 NetBird Agent Network (Beta)
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
> read the docs at **[netbird.ai](https://netbird.ai)**.
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.** **NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.

View File

@@ -1,39 +0,0 @@
# NetBird Agent Network
Agent Network is NetBird's access control layer for AI agents and the people who run
them. It gives every agent a real identity, tied to your identity provider (IdP), and
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
policy, with no API keys to leak.
> **Beta.** Agent Network is open source and can be self-hosted on your own
> infrastructure.
## How it works
Agent Network is built on two existing NetBird capabilities:
- **Overlay network** — the encrypted WireGuard mesh between peers.
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
key server-side, forwards to the API or gateway, and records usage.
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
resources (databases, internal APIs, self-hosted models) are reached directly over
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
## Where the code lives
There is no separate "agent-network" service — it reuses the reverse-proxy and management
components:
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
and runs the per-request middleware pipeline.
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
— the management-side control plane: providers, policies, guardrails, limits, routing,
and usage/access logs.
## Documentation
Full documentation, architecture, and quickstart:
**https://docs.netbird.io/agent-network**

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.24 FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \
@@ -21,7 +21,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -4,7 +4,7 @@
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.24 FROM alpine:3.22.0
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \
@@ -27,7 +27,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -23,7 +24,6 @@ const (
// Profile represents a profile for gomobile // Profile represents a profile for gomobile
type Profile struct { type Profile struct {
ID string
Name string Name string
IsActive bool IsActive bool
} }
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
├── state.json ← Default profile state ├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username) ├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles └── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Legacy work profile config ├── work.json ← Work profile config
├── work.state.json ← Legacy work profile state ├── work.state.json ← Work profile state
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config ├── personal.json ← Personal profile config
── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state ── personal.state.json ← Personal profile state
*/ */
// ProfileManager manages profiles for Android // ProfileManager manages profiles for Android
@@ -99,7 +99,6 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile var profiles []*Profile
for _, p := range internalProfiles { for _, p := range internalProfiles {
profiles = append(profiles, &Profile{ profiles = append(profiles, &Profile{
ID: p.ID.String(),
Name: p.Name, Name: p.Name,
IsActive: p.IsActive, IsActive: p.IsActive,
}) })
@@ -109,65 +108,55 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
} }
// GetActiveProfile returns the currently active profile name // GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) { func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles // Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json // ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState() activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get active profile: %w", err) return "", fmt.Errorf("failed to get active profile: %w", err)
} }
return activeState.Name, nil
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
} }
// SwitchProfile switches to a different profile // SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(id string) error { func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles // Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json // ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{ err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(id), Name: profileName,
Username: androidUsername, Username: androidUsername,
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to switch profile: %w", err) return fmt.Errorf("failed to switch profile: %w", err)
} }
log.Infof("switched to profile: %s", id) log.Infof("switched to profile: %s", profileName)
return nil return nil
} }
// AddProfile creates a new profile // AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error { func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory) // Use ServiceManager (creates profile in profiles/ directory)
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername) if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
if err != nil {
return fmt.Errorf("failed to add profile: %w", err) return fmt.Errorf("failed to add profile: %w", err)
} }
log.Infof("created new profile: %s", profile.ID) log.Infof("created new profile: %s", profileName)
return nil return nil
} }
// LogoutProfile logs out from a profile (clears authentication) // LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(id string) error { func (pm *ProfileManager) LogoutProfile(profileName string) error {
configPath, err := pm.getProfileConfigPath(id) profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil { if err != nil {
return err return err
} }
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists // Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) { if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", id) return fmt.Errorf("profile '%s' does not exist", profileName)
} }
// Read current config using internal profilemanager // Read current config using internal profilemanager
@@ -185,57 +174,53 @@ func (pm *ProfileManager) LogoutProfile(id string) error {
return fmt.Errorf("failed to save config: %w", err) return fmt.Errorf("failed to save config: %w", err)
} }
log.Infof("logged out from profile: %s", id) log.Infof("logged out from profile: %s", profileName)
return nil return nil
} }
// RemoveProfile deletes a profile // RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(id string) error { func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory) // Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil { if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err) return fmt.Errorf("failed to remove profile: %w", err)
} }
log.Infof("removed profile: %s", id) log.Infof("removed profile: %s", profileName)
return nil return nil
} }
// getProfileConfigPath returns the config file path for a profile // getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile) // This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) { func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) { if profileName == "" || profileName == profilemanager.DefaultProfileName {
return "", fmt.Errorf("id %q is not valid", id)
}
if id == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json // Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/ // Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil return filepath.Join(pm.configDir, defaultConfigFilename), nil
} }
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir) profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".json"), nil return filepath.Join(profilesDir, profileName+".json"), nil
} }
// GetConfigPath returns the config file path for a given profile id // GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile() // Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(id string) (string, error) { func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(id) return pm.getProfileConfigPath(profileName)
} }
// GetStateFilePath returns the state file path for a given profile // GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile() // Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) { func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName { if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil return filepath.Join(pm.configDir, "state.json"), nil
} }
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) { profileName = sanitizeProfileName(profileName)
return "", fmt.Errorf("id %q is not valid", id)
}
profilesDir := filepath.Join(pm.configDir, profilesSubdir) profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".state.json"), nil return filepath.Join(profilesDir, profileName+".state.json"), nil
} }
// GetActiveConfigPath returns the config file path for the currently active profile // GetActiveConfigPath returns the config file path for the currently active profile
@@ -245,7 +230,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err) return "", fmt.Errorf("failed to get active profile: %w", err)
} }
return pm.GetConfigPath(activeProfile.ID) return pm.GetConfigPath(activeProfile)
} }
// GetActiveStateFilePath returns the state file path for the currently active profile // GetActiveStateFilePath returns the state file path for the currently active profile
@@ -255,5 +240,18 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err) return "", fmt.Errorf("failed to get active profile: %w", err)
} }
return pm.GetStateFilePath(activeProfile.ID) return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
} }

View File

@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{ resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: string(activeProf.ID), ProfileName: activeProf.Name,
Username: currUser.Username, Username: currUser.Username,
}) })
if err != nil { if err != nil {

View File

@@ -96,19 +96,17 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList() dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
} }
handle := activeProf.ID.String()
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
ProfileName: &handle, ProfileName: &activeProf.Name,
Username: &username, Username: &username,
} }
profileState, err := pm.GetProfileState(activeProf.ID) profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil { if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err) log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" { } else if profileState.Email != "" {
@@ -172,13 +170,14 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil return activeProf, nil
} }
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error { func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username) err := switchProfile(context.Background(), profileName, username)
if err != nil { if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err) return fmt.Errorf("switch profile on daemon: %v", err)
} }
if err := pm.SwitchProfile(resolvedID); err != nil { err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err) return fmt.Errorf("switch profile: %v", err)
} }
@@ -206,15 +205,11 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil return nil
} }
// switchProfile asks the daemon to switch to the profile identified by func switchProfile(ctx context.Context, profileName string, username string) error {
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
//nolint //nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+ return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+ "If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err) "\nnetbird service install \nnetbird service start\n", err)
} }
@@ -222,15 +217,15 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{ _, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle, ProfileName: &profileName,
Username: &username, Username: &username,
}) })
if err != nil { if err != nil {
return "", fmt.Errorf("switch profile failed: %w", err) return fmt.Errorf("switch profile failed: %v", err)
} }
return profilemanager.ID(resp.Id), nil return nil
} }
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error { func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -254,7 +249,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err) return fmt.Errorf("read config file %s: %v", configFilePath, err)
} }
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID) err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -282,7 +277,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil return nil
} }
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error { func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to create auth client: %v", err) return fmt.Errorf("failed to create auth client: %v", err)
@@ -296,7 +291,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := "" jwtToken := ""
if setupKey == "" && needsLogin { if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID) tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
@@ -311,10 +306,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := "" hint := ""
pm := profilemanager.NewProfileManager() pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileID) profileState, err := pm.GetProfileState(profileName)
if err != nil { if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err) log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" { } else if profileState.Email != "" {

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{} sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{ err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "default", Name: "default",
Username: currUser.Username, Username: currUser.Username,
}) })
if err != nil { if err != nil {

View File

@@ -2,16 +2,11 @@ package cmd
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os/user" "os/user"
"strings"
"text/tabwriter"
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -19,8 +14,6 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
var profileListShowID bool
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "Manage NetBird client profiles", Short: "Manage NetBird client profiles",
@@ -38,40 +31,27 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "Add a new profile", Short: "Add a new profile",
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
} }
var profileRenameCmd = &cobra.Command{
Use: "rename <profile> <new_profile_name>",
Short: "Renames an existing profile",
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
Args: cobra.ExactArgs(2),
RunE: renameProfileFunc,
}
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile>", Use: "remove <profile_name>",
Short: "Remove a profile", Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`, Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Aliases: []string{"rm"}, Args: cobra.ExactArgs(1),
Args: cobra.ExactArgs(1), RunE: removeProfileFunc,
RunE: removeProfileFunc,
} }
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile>", Use: "select <profile_name>",
Short: "Select a profile", Short: "Select a profile",
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`, Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error { func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd) SetFlagsFromEnvVars(cmd)
@@ -85,7 +65,6 @@ func setupCmd(cmd *cobra.Command) error {
return nil return nil
} }
func listProfilesFunc(cmd *cobra.Command, _ []string) error { func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil { if err := setupCmd(cmd); err != nil {
return err return err
@@ -104,33 +83,25 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn) daemonClient := proto.NewDaemonServiceClient(conn)
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{ profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username, Username: currUser.Username,
}) })
if err != nil { if err != nil {
return err return err
} }
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0) // list profiles, add a tick if the profile is active
if profileListShowID { cmd.Println("Found", len(profiles.Profiles), "profiles:")
fmt.Fprintln(tw, "ID\tNAME\tACTIVE") for _, profile := range profiles.Profiles {
} else { // use a cross to indicate the passive profiles
fmt.Fprintln(tw, "NAME\tACTIVE") activeMarker := "✗"
}
for _, profile := range resp.Profiles {
marker := ""
if profile.IsActive { if profile.IsActive {
marker = "✓" activeMarker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
id := profilemanager.ID(profile.Id)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
} }
cmd.Println(activeMarker, profile.Name)
} }
return tw.Flush()
return nil
} }
func addProfileFunc(cmd *cobra.Command, args []string) error { func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -138,90 +109,33 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil { if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err) return fmt.Errorf("connect to service CLI interface: %w", err)
} }
defer conn.Close() defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn) daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0] profileName := args[0]
id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username) _, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
if err != nil { ProfileName: profileName,
return err Username: currUser.Username,
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
func renameProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
newProfilename := args[1]
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
Handle: handle,
Username: currUser.Username,
NewProfileName: newProfilename,
}) })
if err != nil { if err != nil {
return wrapAmbiguityError(err, handle) return err
} }
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename) cmd.Println("Profile added successfully:", profileName)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
return nil return nil
} }
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error { func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil { if err := setupCmd(cmd); err != nil {
return err return err
@@ -239,17 +153,18 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
} }
daemonClient := proto.NewDaemonServiceClient(conn) daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{ profileName := args[0]
ProfileName: handle,
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username, Username: currUser.Username,
}) })
if err != nil { if err != nil {
return wrapAmbiguityError(err, handle) return err
} }
cmd.Printf("Profile removed: %s\n", resp.Id) cmd.Println("Profile removed successfully:", profileName)
return nil return nil
} }
@@ -259,7 +174,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
} }
profileManager := profilemanager.NewProfileManager() profileManager := profilemanager.NewProfileManager()
handle := args[0] profileName := args[0]
currUser, err := user.Current() currUser, err := user.Current()
if err != nil { if err != nil {
@@ -276,15 +191,32 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn) daemonClient := proto.NewDaemonServiceClient(conn)
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{ profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
ProfileName: &handle, Username: currUser.Username,
Username: &currUser.Username,
}) })
if err != nil { if err != nil {
return wrapAmbiguityError(err, handle) return fmt.Errorf("list profiles: %w", err)
} }
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil { var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
return err return err
} }
@@ -299,46 +231,6 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
} }
} }
id := profilemanager.ID(switchResp.Id) cmd.Println("Profile switched successfully to:", profileName)
cmd.Printf("Profile switched to: %s\n", id.ShortID())
return nil return nil
} }
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}

View File

@@ -190,7 +190,6 @@ func init() {
// profile commands // profile commands
profileCmd.AddCommand(profileListCmd) profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd) profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRenameCmd)
profileCmd.AddCommand(profileRemoveCmd) profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd) profileCmd.AddCommand(profileSelectCmd)

View File

@@ -1,196 +0,0 @@
//go:build privileged
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}

View File

@@ -1,12 +1,16 @@
package cmd package cmd
import ( import (
"context"
"fmt"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"syscall" "syscall"
"testing" "testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -27,6 +31,186 @@ func TestMain(m *testing.M) {
os.Exit(m.Run()) os.Exit(m.Run())
} }
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing // TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) { func TestServiceEnvVars(t *testing.T) {
tests := []struct { tests := []struct {

View File

@@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -110,10 +111,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
// Resolve the active profile's display name via the daemon, which runs pm := profilemanager.NewProfileManager()
// as root and can read the per-user profile files. The local profile var profName string
// manager only knows the active profile ID, not its display name. if activeProf, err := pm.GetActiveProfile(); err == nil {
profName := getActiveProfileName(ctx) profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{ var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@@ -165,25 +167,6 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil return resp, nil
} }
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected": case "", "idle", "connecting", "connected":

View File

@@ -128,9 +128,16 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool var profileSwitched bool
// switch profile if provided // switch profile if provided
if profileName != "" { if profileName != "" {
if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil { err = switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err) return fmt.Errorf("switch profile: %v", err)
} }
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true profileSwitched = true
} }
@@ -145,52 +152,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched) return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
} }
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error { func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided // override the default profile filepath if provided
if configPath != "" { if configPath != "" {
@@ -229,7 +190,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID) err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -300,10 +261,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
} }
// set the new config // set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username) req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil { if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable { if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon: %s", st.Message()) log.Warnf("setConfig method is not available in the daemon")
} else { } else {
return fmt.Errorf("call service setConfig method: %v", err) return fmt.Errorf("call service setConfig method: %v", err)
} }
@@ -328,11 +289,10 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err) return fmt.Errorf("setup login request: %v", err)
} }
profileID := activeProf.ID.String() loginRequest.ProfileName = &activeProf.Name
loginRequest.ProfileName = &profileID
loginRequest.Username = &username loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.ID) profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil { if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err) log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" { } else if profileState.Email != "" {
@@ -369,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
} }
if _, err := client.Up(ctx, &proto.UpRequest{ if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &profileID, ProfileName: &activeProf.Name,
Username: &username, Username: &username,
}); err != nil { }); err != nil {
return fmt.Errorf("call service up method: %v", err) return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
} }
sm := profilemanager.ServiceManager{} sm := profilemanager.ServiceManager{}
created, err := sm.AddProfile("test1", currUser.Username) err = sm.AddProfile("test1", currUser.Username)
if err != nil { if err != nil {
t.Fatalf("failed to add profile: %v", err) t.Fatalf("failed to add profile: %v", err)
return return
} }
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{ err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: created.ID, Name: "test1",
Username: currUser.Username, Username: currUser.Username,
}) })
if err != nil { if err != nil {

View File

@@ -279,11 +279,9 @@ func (c *Client) Start(startCtx context.Context) error {
select { select {
case <-startCtx.Done(): case <-startCtx.Done():
// ConnectClient.Stop now cancels its own run context and waits for the // Cancel the client context before stopping: Engine.Start blocks on the
// run loop to tear the engine down, so this cancel() is no longer // signal stream while holding the engine mutex and only unblocks on
// required to break the deadlock and could be removed. It is kept as a // cancellation. Stopping first would deadlock on that mutex.
// defensive belt-and-suspenders: cancelling the parent context first
// guarantees the run loop is unblocked even if Stop's contract regresses.
cancel() cancel()
if stopErr := client.Stop(); stopErr != nil { if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iptables package iptables
import ( import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged //go:build !android
package iptables package iptables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package nftables package nftables
import ( import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged //go:build !android
package nftables package nftables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iface package iface
import ( import (

View File

@@ -136,11 +136,6 @@ func (p *ProxyBind) CloseConn() error {
return p.close() return p.close()
} }
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
func (p *ProxyBind) InjectPacket(_ []byte) error {
return nil
}
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil { if p.remoteConn == nil {
return nil return nil

View File

@@ -219,17 +219,6 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
} }
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *ProxyWrapper) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error { func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil { if p.cancel == nil {

View File

@@ -18,9 +18,4 @@ type Proxy interface {
RedirectAs(endpoint *net.UDPAddr) RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func()) SetDisconnectListener(disconnected func())
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
InjectPacket(b []byte) error
} }

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged //go:build linux && !android
package wgproxy package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build !linux || !privileged //go:build !linux
package wgproxy package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged //go:build linux && !android
package wgproxy package wgproxy
@@ -26,6 +26,64 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
} }
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses // TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) { func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852 wgPort := 51852
@@ -198,64 +256,6 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
} }
} }
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints // TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) { func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856 wgPort := 51856

View File

@@ -147,17 +147,6 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.sendPkg = p.srcFakerConn.SendPkg p.sendPkg = p.srcFakerConn.SendPkg
} }
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *WGUDPProxy) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the localConn // CloseConn close the localConn
func (p *WGUDPProxy) CloseConn() error { func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil { if p.cancel == nil {

View File

@@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
@@ -31,13 +30,11 @@ type Manager interface {
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
ipsetCounter int ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{} routeRules map[id.RuleID]struct{}
previousConfigHash uint64 mutex sync.Mutex
hasAppliedConfig bool
mutex sync.Mutex
} }
func NewDefaultManager(fm firewall.Manager) *DefaultManager { func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -60,23 +57,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return return
} }
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -90,49 +70,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag) if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
if routeErr != nil { log.Errorf("Failed to apply route ACLs: %v", err)
log.Errorf("Failed to apply route ACLs: %v", routeErr)
} }
flushErr := d.firewall.Flush() if err := d.firewall.Flush(); err != nil {
if flushErr != nil { log.Error("failed to flush firewall rules: ", err)
log.Error("failed to flush firewall rules: ", flushErr)
} }
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
} }
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,7 +1,6 @@
package acl package acl
import ( import (
"fmt"
"net/netip" "net/netip"
"testing" "testing"
@@ -486,149 +485,3 @@ func TestPortInfoEmpty(t *testing.T) {
}) })
} }
} }
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -11,7 +11,6 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -55,10 +54,6 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
type ConnectClient struct { type ConnectClient struct {
ctx context.Context ctx context.Context
runCancel context.CancelFunc
runExited chan struct{}
runOnce sync.Once
runStarted atomic.Bool
config *profilemanager.Config config *profilemanager.Config
statusRecorder *peer.Status statusRecorder *peer.Status
@@ -75,14 +70,8 @@ func NewConnectClient(
config *profilemanager.Config, config *profilemanager.Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *ConnectClient { ) *ConnectClient {
// Derive the run context here so Stop owns the cancel that unblocks the run
// loop. runCancel is set once at construction, so Stop can call it without
// racing the run loop's startup. Callers therefore need not cancel before Stop.
runCtx, runCancel := context.WithCancel(ctx)
return &ConnectClient{ return &ConnectClient{
ctx: runCtx, ctx: ctx,
runCancel: runCancel,
runExited: make(chan struct{}),
config: config, config: config,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
engineMutex: sync.Mutex{}, engineMutex: sync.Mutex{},
@@ -146,11 +135,6 @@ func (c *ConnectClient) RunOniOS(
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
// Mark the loop as started and signal exit on return so Stop can wait for
// the loop to finish (and skip the wait if the loop never ran).
c.runStarted.Store(true)
defer c.runOnce.Do(func() { close(c.runExited) })
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -306,7 +290,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debug(err) log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
c.runCancel() _ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error return backoff.Permanent(wrapErr(err)) // unrecoverable error
} }
return wrapErr(err) return wrapErr(err)
@@ -426,10 +410,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = nil c.engine = nil
c.engineMutex.Unlock() c.engineMutex.Unlock()
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled") // todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil { if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
}
} }
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
@@ -445,12 +433,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
c.statusRecorder.ClientStart() c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx)) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
c.runCancel() _ = c.Stop()
} }
return err return err
} }
@@ -528,9 +516,11 @@ func (c *ConnectClient) Status() StatusType {
} }
func (c *ConnectClient) Stop() error { func (c *ConnectClient) Stop() error {
c.runCancel() engine := c.Engine()
if c.runStarted.Load() { if engine != nil {
<-c.runExited if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
} }
return nil return nil
} }

View File

@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key", "PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key", "SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized", "ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields", "policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
} }

View File

@@ -51,20 +51,13 @@ type cachedRecord struct {
} }
// Resolver caches critical NetBird infrastructure domains. // Resolver caches critical NetBird infrastructure domains.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all // records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// guarded by mutex.
type Resolver struct { type Resolver struct {
records map[dns.Question]*cachedRecord records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver chain ChainResolver
chainMaxPriority int chainMaxPriority int
refreshGroup singleflight.Group refreshGroup singleflight.Group
@@ -83,10 +76,9 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question]*cachedRecord), records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool), refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time), cacheTTL: resolveCacheTTL(),
cacheTTL: resolveCacheTTL(),
} }
} }
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache. // AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale // A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype. When one family hard-errors while the other succeeds, // entry for that qtype.
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords)) d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil return nil
} }
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA) delete(m.records, qAAAA)
delete(m.refreshing, qA) delete(m.refreshing, qA)
delete(m.refreshing, qAAAA) delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains() currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains) removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
} }
m.addNewDomains(ctx, newDomains) m.addNewDomains(ctx, newDomains)
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain return m.mgmtDomain != nil && domain == *m.mgmtDomain
} }
// addNewDomains resolves and caches domains that are not yet in the cache, // addNewDomains resolves and caches all domains from the update
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains { for _, newDomain := range newDomains {
if _, dup := seen[newDomain]; dup { if err := m.AddDomain(ctx, newDomain); err != nil {
continue log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} } else {
seen[newDomain] = struct{}{} log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
} }
} }
} }

View File

@@ -21,7 +21,6 @@ type fakeChain struct {
mu sync.Mutex mu sync.Mutex
calls map[string]int calls map[string]int
answers map[string][]dns.RR answers map[string][]dns.RR
qErr map[string]error
err error err error
hasRoot bool hasRoot bool
onLookup func() onLookup func()
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
return &fakeChain{ return &fakeChain{
calls: map[string]int{}, calls: map[string]int{},
answers: map[string][]dns.RR{}, answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true, hasRoot: true,
} }
} }
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++ f.calls[key]++
answers := f.answers[key] answers := f.answers[key]
err := f.err err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup onLookup := f.onLookup
f.mu.Unlock() f.mu.Unlock()
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
} }
} }
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int { func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()

View File

@@ -1,183 +0,0 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -8,7 +8,6 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"slices"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -168,10 +167,7 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA: case dns.TypeA:
alternativeNetwork = "ip6" alternativeNetwork = "ip6"
default: default:
// Non-address types reach LookupIP only unexpectedly; without an return dns.RcodeNameError
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
} }
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil { if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -188,230 +184,6 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess return dns.RcodeSuccess
} }
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging. // FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string { func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 { if len(answers) == 0 {
@@ -435,35 +207,3 @@ func FormatAnswers(answers []dns.RR) string {
} }
return "[" + strings.Join(parts, ", ") + "]" return "[" + strings.Join(parts, ", ") + "]"
} }
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
// RFC 6891 a responder must not include an OPT RR toward a client that did not
// advertise EDNS0.
func StripOPT(msg *dns.Msg) {
if len(msg.Extra) == 0 {
return
}
out := msg.Extra[:0]
for _, rr := range msg.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
msg.Extra = out
}
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
// the message, if present.
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
opt := msg.IsEdns0()
if opt == nil {
return nil, false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok {
return ede, true
}
}
return nil, false
}

View File

@@ -5,7 +5,6 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -121,200 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL") assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
} }
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
StripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestExtractEDE(t *testing.T) {
t.Run("no edns", func(t *testing.T) {
_, ok := ExtractEDE(&dns.Msg{})
assert.False(t, ok, "message without OPT has no EDE")
})
t.Run("edns without ede", func(t *testing.T) {
rm := &dns.Msg{}
rm.SetEdns0(4096, false)
_, ok := ExtractEDE(rm)
assert.False(t, ok, "OPT without EDE option returns false")
})
t.Run("with ede", func(t *testing.T) {
rm := &dns.Msg{}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
rm.Extra = append(rm.Extra, opt)
ede, ok := ExtractEDE(rm)
assert.True(t, ok, "EDE option should be found")
assert.Equal(t, uint16(49152), ede.InfoCode)
assert.Equal(t, "upstream timeout", ede.ExtraText)
})
}

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url" "net/url"
"os"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -39,15 +38,11 @@ const (
// defaultWarningDelayBase is the starting grace window before a // defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's // "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no // never been healthy and only has overlay upstreams with no
// Connected peer. Per-server and overridable via envWarningDelay; // Connected peer. Per-server and overridable; see warningDelayFor.
// see warningDelay. defaultWarningDelayBase = 30 * time.Second
defaultWarningDelayBase = 60 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the // warningDelayBonusCap caps the route-count bonus added to the
// base grace window. See warningDelay. // base grace window. See warningDelayFor.
warningDelayBonusCap = 30 * time.Second warningDelayBonusCap = 30 * time.Second
// envWarningDelay overrides defaultWarningDelayBase with a Go duration
// string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
) )
// errNoUsableNameservers signals that a merged-domain group has no usable // errNoUsableNameservers signals that a merged-domain group has no usable
@@ -140,7 +135,7 @@ type DefaultServer struct {
disableSys bool disableSys bool
mux sync.Mutex mux sync.Mutex
service service service service
dnsMuxHandlers []handlerWrapper dnsMuxMap registeredHandlerMap
localResolver *local.Resolver localResolver *local.Resolver
wgInterface WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
@@ -204,6 +199,8 @@ type handlerWrapper struct {
priority int priority int
} }
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer // DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct { type DefaultServerConfig struct {
WgInterface WGIface WgInterface WGIface
@@ -292,6 +289,7 @@ func newDefaultServer(
service: dnsService, service: dnsService,
handlerChain: handlerChain, handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
@@ -300,7 +298,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver, mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
warningDelayBase: warningDelayBaseFromEnv(), warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1), healthRefresh: make(chan struct{}, 1),
} }
// Wire the local resolver against the peer status recorder so it can // Wire the local resolver against the peer status recorder so it can
@@ -330,7 +328,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface { type routeSettable interface {
setSelectedRoutes(func() route.HAMap) setSelectedRoutes(func() route.HAMap)
} }
for _, entry := range s.dnsMuxHandlers { for _, entry := range s.dnsMuxMap {
if h, ok := entry.handler.(routeSettable); ok { if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected) h.setSelectedRoutes(selected)
} }
@@ -980,23 +978,19 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests // this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxHandlers { for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority) s.deregisterHandler([]string{existing.domain}, existing.priority)
// The local resolver is a persistent singleton shared by every custom existing.handler.Stop()
// zone and reused across config updates. Its chain registrations are
// per-config and must be deregistered, but Stop() cancels its lookup
// context (breaking external CNAME-target resolution) and clears its
// records, so it must not be torn down here.
if existing.handler != s.localResolver {
existing.handler.Stop()
}
} }
muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority) s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
} }
s.dnsMuxHandlers = muxUpdates s.dnsMuxMap = muxUpdateMap
} }
// updateNSGroupStates records the new group set and pokes the refresher. // updateNSGroupStates records the new group set and pokes the refresher.
@@ -1160,26 +1154,6 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false return false
} }
// warningDelayBaseFromEnv returns the base grace window, honoring
// envWarningDelay when it holds a valid positive Go duration. Invalid or
// non-positive values fall back to defaultWarningDelayBase.
func warningDelayBaseFromEnv() time.Duration {
val := os.Getenv(envWarningDelay)
if val == "" {
return defaultWarningDelayBase
}
d, err := time.ParseDuration(val)
if err != nil {
log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
return defaultWarningDelayBase
}
if d <= 0 {
log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
return defaultWarningDelayBase
}
return d
}
// warningDelay returns the grace window for the given selected-route // warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by // count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows // warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1230,7 +1204,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler. // in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth { func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth) merged := make(map[netip.AddrPort]UpstreamHealth)
for _, entry := range s.dnsMuxHandlers { for _, entry := range s.dnsMuxMap {
reporter, ok := entry.handler.(upstreamHealthReporter) reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok { if !ok {
continue continue

View File

@@ -1,485 +0,0 @@
//go:build privileged
package dns
import (
"context"
"fmt"
"net/netip"
"os"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}

View File

@@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
@@ -102,6 +104,481 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
}
for key := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key]
if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) { func TestDNSServerStartStop(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -552,15 +1029,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {} func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) { func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := []handlerWrapper{ baseMatchHandlers := registeredHandlerMap{
{ "upstream-group1": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
{ "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
@@ -569,15 +1046,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}, },
} }
baseRootHandlers := []handlerWrapper{ baseRootHandlers := registeredHandlerMap{
{ "upstream-root1": {
domain: ".", domain: ".",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-root1", Id: "upstream-root1",
}, },
priority: PriorityDefault, priority: PriorityDefault,
}, },
{ "upstream-root2": {
domain: ".", domain: ".",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-root2", Id: "upstream-root2",
@@ -586,22 +1063,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}, },
} }
baseMixedHandlers := []handlerWrapper{ baseMixedHandlers := registeredHandlerMap{
{ "upstream-group1": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
{ "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityUpstream - 1,
}, },
{ "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
@@ -612,7 +1089,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
initialHandlers []handlerWrapper initialHandlers registeredHandlerMap
updates []handlerWrapper updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain expectedHandlers map[string]string // map[HandlerID]domain
description string description string
@@ -896,38 +1373,32 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{ server := &DefaultServer{
dnsMuxHandlers: tt.initialHandlers, dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
service: &mockService{}, service: &mockService{},
} }
// Perform the update // Perform the update
server.updateMux(tt.updates) server.updateMux(tt.updates)
// Verify the results // Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers), assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
"Number of handlers after update doesn't match expected") "Number of handlers after update doesn't match expected")
// Check each expected handler // Check each expected handler
for id, expectedDomain := range tt.expectedHandlers { for id, expectedDomain := range tt.expectedHandlers {
var found *handlerWrapper handler, exists := server.dnsMuxMap[types.HandlerID(id)]
for i := range server.dnsMuxHandlers { assert.True(t, exists, "Expected handler %s not found", id)
if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) { if exists {
found = &server.dnsMuxHandlers[i] assert.Equal(t, expectedDomain, handler.domain,
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
"Domain mismatch for handler %s", id) "Domain mismatch for handler %s", id)
} }
} }
// Verify no unexpected handlers exist // Verify no unexpected handlers exist
for _, entry := range server.dnsMuxHandlers { for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(entry.handler.ID())] _, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID()) assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
} }
// Verify the handlerChain state and order // Verify the handlerChain state and order
@@ -942,7 +1413,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux // Verify handler exists in mux
foundInMux := false foundInMux := false
for _, muxEntry := range server.dnsMuxHandlers { for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler && if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority && chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -951,108 +1422,12 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
} }
} }
assert.True(t, foundInMux, assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxHandlers") "Handler in chain not found in dnsMuxMap")
} }
}) })
} }
} }
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
func TestExtraDomains(t *testing.T) { func TestExtraDomains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -1674,6 +2049,7 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
} }
groups := []*nbdns.NameServerGroup{ groups := []*nbdns.NameServerGroup{
@@ -1831,7 +2207,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
} }
} }
// healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed // healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates // UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers. // without spinning up real handlers.
type healthStubHandler struct { type healthStubHandler struct {
@@ -1907,11 +2283,12 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(), ctx: context.Background(),
wgInterface: &mocWGIface{}, wgInterface: &mocWGIface{},
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected }, selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active }, activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase, warningDelayBase: defaultWarningDelayBase,
} }
fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}} fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock() fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2018,6 +2395,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
wgInterface: &mocWGIface{}, wgInterface: &mocWGIface{},
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil }, selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond, warningDelayBase: 50 * time.Millisecond,
@@ -2029,7 +2407,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}} }}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}} server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2066,6 +2444,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface), service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{}, extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"), statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil }, selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
@@ -2080,7 +2459,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
} }
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}} server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2105,32 +2484,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer // rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No // comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either. // warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t) fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2242,6 +2595,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{ server := &DefaultServer{
ctx: context.Background(), ctx: context.Background(),
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap }, selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour, warningDelayBase: time.Hour,
@@ -2259,7 +2613,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"}, overlay: {LastFail: time.Now(), LastErr: "timeout"},
}, },
} }
server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}} server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2286,6 +2640,7 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
} }
tests := []struct { tests := []struct {

View File

@@ -443,32 +443,29 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
} }
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := "" proto := ""
if upstreamProto != nil { if upstreamProto != nil {
proto = upstreamProto.protocol proto = upstreamProto.protocol
} }
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok { if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns { if !hadEdns {
resutil.StripOPT(rm) stripOPT(rm)
} }
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
} }
reason := dns.RcodeToString[rm.Rcode] reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
} }
if !hadEdns { if !hadEdns {
resutil.StripOPT(rm) stripOPT(rm)
} }
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
} }
@@ -523,6 +520,22 @@ func upstreamUDPSize() uint16 {
return dns.MinMsgSize return dns.MinMsgSize
} }
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()} return &upstreamFailure{upstream: upstream, reason: err.Error()}

View File

@@ -517,78 +517,6 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
} }
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) { func TestFormatFailures(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -985,6 +913,19 @@ func TestEDEName(t *testing.T) {
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric") assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
} }
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53") upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53") upstream2 := netip.MustParseAddrPort("192.0.2.2:53")

View File

@@ -26,23 +26,8 @@ import (
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
// EDE info codes the forwarder emits on upstream failures so the querying
// client can see the reason without inspecting this peer's logs. They live in
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
// real upstream EDE here, so these cannot collide with a genuine code.
const (
edeNetbirdUpstreamTimeout uint16 = 49152
edeNetbirdUpstreamFailure uint16 = 49153
)
type resolver interface { type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
} }
type firewaller interface { type firewaller interface {
@@ -216,6 +201,12 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query) resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, ".")) mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" { if mostSpecificResId == "" {
@@ -227,46 +218,9 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil { if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime) f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return return
} }
@@ -277,25 +231,6 @@ func (f *DNSForwarder) handleAddressQuery(
f.writeResponse(logger, w, resp, qname, startTime) f.writeResponse(logger, w, resp, qname, startTime)
} }
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) { func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
@@ -398,7 +333,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg, resp *dns.Msg,
domain string, domain string,
result resutil.LookupResult, result resutil.LookupResult,
reqHasEdns bool,
startTime time.Time, startTime time.Time,
) { ) {
qType := question.Qtype qType := question.Qtype
@@ -440,10 +374,6 @@ func (f *DNSForwarder) handleDNSError(
logger.Warnf(errResolveFailed, domain, result.Err) logger.Warnf(errResolveFailed, domain, result.Err)
} }
if reqHasEdns {
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
}
f.writeResponse(logger, w, resp, domain, startTime) f.writeResponse(logger, w, resp, domain, startTime)
} }
@@ -484,33 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
func edeCodeFor(dnsErr *net.DNSError) uint16 {
if dnsErr != nil && dnsErr.IsTimeout {
return edeNetbirdUpstreamTimeout
}
return edeNetbirdUpstreamFailure
}
// edeText builds the EDE extra-text describing the class of upstream failure.
// It deliberately omits the upstream server address, which may be an internal
// resolver and is exposed to any client permitted to use the route; the full
// detail stays in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
if dnsErr != nil && dnsErr.IsTimeout {
return "netbird forwarder: upstream timeout"
}
return "netbird forwarder: upstream failure"
}
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -133,41 +132,6 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1) return args.Get(0).([]netip.Addr), args.Error(1)
} }
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) { func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -580,15 +544,12 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
} }
func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct { tests := []struct {
name string name string
queryType uint16 queryType uint16
queryDomain string queryDomain string
configured string configured string
expectedCode int expectedCode int
expectEDE bool
description string description string
}{ }{
{ {
@@ -600,13 +561,28 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries", description: "RFC compliant REFUSED for unauthorized queries",
}, },
{ {
name: "unsupported query type returns NODATA", name: "unsupported query type returns NOTIMP",
queryType: dns.TypeCAA, queryType: dns.TypeMX,
queryDomain: "example.com", queryDomain: "example.com",
configured: "example.com", configured: "example.com",
expectedCode: dns.RcodeSuccess, expectedCode: dns.RcodeNotImplemented,
expectEDE: true, description: "RFC compliant NOTIMP for unsupported types",
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP", },
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
}, },
} }
@@ -622,7 +598,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{} query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType) query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response // Capture the written response
var writtenResp *dns.Msg var writtenResp *dns.Msg
@@ -638,288 +613,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer // Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written") require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string
lookupErr error
reqEdns bool
wantEDE bool
wantCode uint16
wantTextHas string
}{
{
name: "timeout with edns0",
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamTimeout,
wantTextHas: "netbird forwarder: upstream timeout",
},
{
name: "server failure with edns0",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamFailure,
wantTextHas: "netbird forwarder: upstream failure",
},
{
name: "no edns0 in request omits ede",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: false,
wantEDE: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr(nil), tt.lookupErr).Once()
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
if tt.reqEdns {
query.SetEdns0(dns.DefaultMsgSize, false)
}
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
mockResolver.AssertExpectations(t)
require.NotNil(t, writtenResp, "expected a response")
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
ede, ok := resutil.ExtractEDE(writtenResp)
if !tt.wantEDE {
assert.False(t, ok, "response must not carry EDE")
return
}
require.True(t, ok, "response must carry EDE")
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
}) })
} }
} }

View File

@@ -82,18 +82,10 @@ const (
PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms PeerConnectionTimeoutMin = 30000 // ms
disableAutoUpdate = "disabled" disableAutoUpdate = "disabled"
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
// check gathering. The gathering runs uncancellable system calls (process scan,
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
systemInfoTimeout = 15 * time.Second
) )
var ErrResetConnection = fmt.Errorf("reset connection") var ErrResetConnection = fmt.Errorf("reset connection")
var ErrEngineAlreadyStarted = errors.New("engine already started")
type EngineConfig struct { type EngineConfig struct {
WgPort int WgPort int
WgIfaceName string WgIfaceName string
@@ -207,8 +199,6 @@ type Engine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
started bool
wgInterface WGIface wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
@@ -216,12 +206,6 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
// forwardingRules holds the ingress forward rules applied for the current target.
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
// it is stashed here so the final, peer-converged pass can build the lazy-connection
// exclude list without recomputing them on every bounded peer pass.
forwardingRules []firewallManager.ForwardRule
networkMonitor *networkmonitor.NetworkMonitor networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer sshServer sshServer
@@ -295,15 +279,9 @@ func NewEngine(
services EngineServices, services EngineServices,
mobileDep MobileDependency, mobileDep MobileDependency,
) *Engine { ) *Engine {
// The engine is single-use: a fresh instance is built per connection
// cycle (see Client.run), so the run context is created once here rather
// than in Start.
ctx, cancel := context.WithCancel(clientCtx)
engine := &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: services.SignalClient, signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient, mgmClient: services.MgmClient,
@@ -336,34 +314,8 @@ func (e *Engine) Stop() error {
log.Debugf("tried stopping engine that is nil") log.Debugf("tried stopping engine that is nil")
return nil return nil
} }
e.cancel()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
e.stopLocked()
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// stopLocked tears down everything Start may have brought up, in the order
// teardown requires (DNS before the interface goes down, flow manager after).
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
// path, so a partially-initialized engine is cleaned up the same way; every
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
// after releasing the lock, since the goroutines also take syncMsgMux.
func (e *Engine) stopLocked() {
if e.connMgr != nil { if e.connMgr != nil {
e.connMgr.Close() e.connMgr.Close()
} }
@@ -414,6 +366,10 @@ func (e *Engine) stopLocked() {
// so dbus and friends don't complain because of a missing interface // so dbus and friends don't complain because of a missing interface
e.stopDNSServer() e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish e.jobExecutorWG.Wait() // block until job goroutines finish
e.close() e.close()
@@ -432,6 +388,21 @@ func (e *Engine) stopLocked() {
if err := e.stateManager.PersistState(context.Background()); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
} }
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. // calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
@@ -469,38 +440,18 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// The engine is single-use. Reject a duplicate start and a start on an if err := iface.ValidateMTU(e.config.MTU); err != nil {
// already-stopped engine (run context cancelled).
if e.started {
return ErrEngineAlreadyStarted
}
if ctxErr := e.ctx.Err(); ctxErr != nil {
return fmt.Errorf("engine already stopped: %w", ctxErr)
}
e.started = true
// Tear down any partially-initialized state on a failed start. Cancel the
// run context first so goroutines started before the failure (connMgr,
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
// not just what close() covers.
defer func() {
if err != nil {
e.cancel()
e.stopLocked()
}
}()
if err = iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err) return fmt.Errorf("invalid MTU configuration: %w", err)
} }
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient) e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
@@ -534,11 +485,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings() initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil { if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err) return fmt.Errorf("read initial settings: %w", err)
} }
dnsServer, err := e.newDnsServer(dnsConfig) dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil { if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err) return fmt.Errorf("create dns server: %w", err)
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
@@ -573,6 +526,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err = e.wgInterfaceCreate(); err != nil { if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
} }
@@ -581,6 +535,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
} }
if err := e.createFirewall(); err != nil { if err := e.createFirewall(); err != nil {
e.close()
return err return err
} }
@@ -592,6 +547,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.udpMux, err = e.wgInterface.Up() e.udpMux, err = e.wgInterface.Up()
if err != nil { if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
@@ -616,7 +572,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
if err := e.dnsServer.Initialize(); err != nil { err = e.dnsServer.Initialize()
if err != nil {
e.close()
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
@@ -628,9 +586,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed()) e.srWatcher.Start(peer.IsForceRelayed())
if err = e.receiveSignalEvents(); err != nil { e.receiveSignalEvents()
return err
}
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveJobEvents() e.receiveJobEvents()
@@ -682,6 +638,7 @@ func (e *Engine) createFirewall() error {
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil { if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err) return fmt.Errorf("set firewall: %w", err)
} }
@@ -774,15 +731,7 @@ func (e *Engine) blockLanAccess() {
// modifyPeers updates peers that have been modified (e.g. IP address has been changed). // modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one. // It closes the existing connection, removes it from the peerConns map, and creates a new one.
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
// batch at a time and other subsystems can interleave between passes. It is
// passed in (not read globally) so tests can exercise the multi-pass path.
const maxPeersPerSyncPass = 300
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
// when more changed peers remained than the cap, so the caller re-runs.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
// first, check if peers have been modified // first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
@@ -812,32 +761,26 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
} }
} }
more := false
if len(modified) > maxBatch {
modified = modified[:maxBatch]
more = true
}
// second, close all modified connections and remove them from the state map // second, close all modified connections and remove them from the state map
for _, p := range modified { for _, p := range modified {
if err := e.removePeer(p.GetWgPubKey()); err != nil { err := e.removePeer(p.GetWgPubKey())
return false, err if err != nil {
return err
} }
} }
// third, add the peer connections again // third, add the peer connections again
for _, p := range modified { for _, p := range modified {
if err := e.addNewPeer(p); err != nil { err := e.addNewPeer(p)
return false, err if err != nil {
return err
} }
} }
return more, nil return nil
} }
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
// removePeers removes up to maxBatch peers per call. It returns true when more func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// peers remained to remove than the cap, so the caller re-runs.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
newPeers := make([]string, 0, len(peersUpdate)) newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate { for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey()) newPeers = append(newPeers, p.GetWgPubKey())
@@ -845,19 +788,14 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
more := false
if len(toRemove) > maxBatch {
toRemove = toRemove[:maxBatch]
more = true
}
for _, p := range toRemove { for _, p := range toRemove {
if err := e.removePeer(p); err != nil { err := e.removePeer(p)
return false, err if err != nil {
return err
} }
log.Infof("removed peer %s", p) log.Infof("removed peer %s", p)
} }
return more, nil return nil
} }
func (e *Engine) removeAllPeers() error { func (e *Engine) removeAllPeers() error {
@@ -926,38 +864,27 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate) e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
} }
// phase times a sync sub-phase: it returns a function that records the elapsed func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// duration when called. Starting the timer at the call site keeps inter-phase started := time.Now()
// glue code out of the measurement. defer func() {
func (e *Engine) phase(name string) func() { duration := time.Since(started)
start := time.Now() log.Infof("sync finished in %s", duration)
return func() { e.clientMetrics.RecordSyncDuration(e.ctx, duration)
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start)) }()
}
}
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
// returns true if more peers remained than the per-pass cap. It is driven by the
// mapStateManager, which re-invokes it (releasing the lock between passes) until
// the update is fully applied.
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown // Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil { if e.ctx.Err() != nil {
return false, e.ctx.Err() return e.ctx.Err()
} }
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
} }
done := e.phase("netbird_config") if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
err := e.updateNetbirdConfig(update.GetNetbirdConfig()) return err
done()
if err != nil {
return false, err
} }
// Posture checks are bound to the network map presence: // Posture checks are bound to the network map presence:
@@ -967,25 +894,23 @@ func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (b
// leave the previously applied checks untouched // leave the previously applied checks untouched
nm := update.GetNetworkMap() nm := update.GetNetworkMap()
if nm == nil { if nm == nil {
return false, nil return nil
} }
done = e.phase("checks") if err := e.updateChecksIfNew(update.Checks); err != nil {
err = e.updateChecksIfNew(update.Checks) return err
done()
if err != nil {
return false, err
} }
e.persistSyncResponse(update)
// only apply new changes and ignore old ones // only apply new changes and ignore old ones
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass) if err := e.updateNetworkMap(nm); err != nil {
if err != nil { return err
return false, err
} }
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil) e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
return more, nil return nil
} }
// updateNetbirdConfig applies the management-provided NetBird configuration: // updateNetbirdConfig applies the management-provided NetBird configuration:
@@ -1031,13 +956,6 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled / // (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed. // engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) { func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
// Only persist updates that carry a network map. Config-only updates (e.g. relay
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
// the last full map on disk and break restore-on-restart.
if update.GetNetworkMap() == nil {
return
}
e.syncRespMux.RLock() e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock() defer e.syncRespMux.RUnlock()
@@ -1117,22 +1035,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
} }
e.checks = checks e.checks = checks
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...) info, err := system.GetInfoWithChecks(e.ctx, checks)
if !ok { if err != nil {
// Gathering timed out; skip the meta sync this cycle rather than blocking the log.Warnf("failed to get system info with checks: %v", err)
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry. info = system.GetInfo(e.ctx)
return nil
} }
e.applyInfoFlags(info)
if err := e.mgmClient.SyncMeta(info); err != nil {
return fmt.Errorf("could not sync meta: error %s", err)
}
return nil
}
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
func (e *Engine) applyInfoFlags(info *system.Info) {
info.SetFlags( info.SetFlags(
e.config.RosenpassEnabled, e.config.RosenpassEnabled,
e.config.RosenpassPermissive, e.config.RosenpassPermissive,
@@ -1151,20 +1058,12 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
e.config.EnableSSHRemotePortForwarding, e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth, e.config.DisableSSHAuth,
) )
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it if err := e.mgmClient.SyncMeta(info); err != nil {
// can be excluded from the reported network addresses; the interface coming and log.Errorf("could not sync meta: error %s", err)
// going otherwise churns the peer meta on the management server. return err
func (e *Engine) overlayAddresses() []netip.Addr {
var ips []netip.Addr
if e.config.WgAddr.IP.IsValid() {
ips = append(ips, e.config.WgAddr.IP)
} }
if e.config.WgAddr.HasIPv6() { return nil
ips = append(ips, e.config.WgAddr.IPv6)
}
return ips
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -1310,32 +1209,31 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1) e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done() defer e.shutdownWg.Done()
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...) info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if !ok { if err != nil {
// Gathering timed out; connect the stream with base info so management log.Warnf("failed to get system info with checks: %v", err)
// connectivity still comes up rather than blocking here.
info = system.GetInfo(e.ctx) info = system.GetInfo(e.ctx)
} }
e.applyInfoFlags(info) info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
// The map-state manager converges the latest update in the background in err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
// bounded passes; the stream callback only hands it the newest target.
persist := func(u *mgmProto.SyncResponse) {
done := e.phase("persist")
e.persistSyncResponse(u)
done()
}
manager := newMapStateManager(e.applySyncPass, persist, func(d time.Duration) {
log.Infof("sync finished in %s", d)
e.clientMetrics.RecordSyncDuration(e.ctx, d)
})
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
manager.run(e.ctx)
}()
err := e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
@@ -1386,107 +1284,21 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
return nil return nil
} }
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// and up to maxBatch peers per phase. It returns true when more peers remained
// than the cap, so the caller re-runs until convergence.
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil { if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig()) err := e.updateConfig(networkMap.GetPeerConfig())
if err != nil { if err != nil {
return false, err return err
} }
} }
serial := networkMap.GetSerial() serial := networkMap.GetSerial()
if e.networkSerial > serial { if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial) log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
return false, nil return nil
} }
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
// up-front and only once per target: they are cheap, local, idempotent and must
// be in place before peers come up (fail-closed). On the bounded re-runs that only
// drain the remaining peer batches they are skipped — the applied forward rules are
// reused from e.forwardingRules for the lazy-exclude finalize.
if firstPass {
e.applyWholesale(networkMap, serial)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
doneOffline := e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
doneOffline()
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// No special case for cleanup: when management signals RemotePeersIsEmpty (e.g. our
// peer was deleted), remotePeers is already empty, so the bounded diff below removes
// every peer in batches — same path as a normal update, no unbounded removeAllPeers
// held under syncMsgMux in one shot.
doneRemoved := e.phase("removed_peers")
removeMore, err := e.removePeers(remotePeers, maxBatch)
doneRemoved()
if err != nil {
return false, err
}
doneModified := e.phase("modified_peers")
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
doneModified()
if err != nil {
return false, err
}
doneAdded := e.phase("added_peers")
addMore, err := e.addNewPeers(remotePeers, maxBatch)
doneAdded()
if err != nil {
return false, err
}
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
needMore := removeMore || modifyMore || addMore
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// Set the exclude list only once peers have fully converged (this pass added
// the last batch). It needs all target peers present in the store, and
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
if !needMore {
doneLazy := e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
doneLazy()
}
e.networkSerial = serial
return needMore, nil
}
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
// ingress forward rules — that must be in place before peers come up. It runs once
// per target (first pass only); the resulting forward rules are stashed in
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil { if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err) log.Errorf("failed to update lazy connection feature flag: %v", err)
} }
@@ -1514,16 +1326,13 @@ func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64)
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address()) dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1532,34 +1341,79 @@ func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64)
e.connMgr.UpdateRouteHAMap(clientRoutes) e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes)) log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
} }
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err) log.Errorf("failed to update routes: %v", err)
} }
done()
done = e.phase("filtering")
if e.acl != nil { if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag) e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
} }
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules // Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil { if err != nil {
log.Errorf("failed to update forward rules, err: %v", err) log.Errorf("failed to update forward rules, err: %v", err)
} }
done()
e.forwardingRules = forwardingRules log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return err
}
} else {
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
return nil
} }
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool { func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1739,23 +1593,14 @@ func addrToString(addr netip.Addr) string {
} }
// addNewPeers adds peers that were not know before but arrived from the Management service with the update // addNewPeers adds peers that were not know before but arrived from the Management service with the update
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// when more new peers remained than the cap, so the caller re-runs.
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
added := 0
for _, p := range peersUpdate { for _, p := range peersUpdate {
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok { err := e.addNewPeer(p)
continue // already present (cheap skip), does not count toward the cap if err != nil {
return err
} }
if added >= maxBatch {
return true, nil // at least one more new peer remains
}
if err := e.addNewPeer(p); err != nil {
return false, err
}
added++
} }
return false, nil return nil
} }
// addNewPeer add peer if connection doesn't exist // addNewPeer add peer if connection doesn't exist
@@ -1853,7 +1698,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
} }
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() error { func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1) e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done() defer e.shutdownWg.Done()
@@ -1869,13 +1714,6 @@ func (e *Engine) receiveSignalEvents() error {
return e.ctx.Err() return e.ctx.Err()
} }
// Self-addressed heartbeat: the signal client's receive watchdog
// round-trips this through the server to confirm the receive stream
// is delivering. Liveness is already recorded before this handler.
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
return nil
}
conn, ok := e.peerStore.PeerConn(msg.Key) conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1924,12 +1762,7 @@ func (e *Engine) receiveSignalEvents() error {
} }
}() }()
// todo: consider to remove this blocker. I do not see benefit to block the Start operations e.signal.WaitStreamConnected()
e.signal.WaitStreamConnected(e.ctx)
if err := e.ctx.Err(); err != nil {
return fmt.Errorf("wait for signal stream: %w", err)
}
return nil
} }
func (e *Engine) parseNATExternalIPMappings() []string { func (e *Engine) parseNATExternalIPMappings() []string {

View File

@@ -1,565 +0,0 @@
//go:build privileged
package internal
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers were started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}

View File

@@ -6,18 +6,37 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -31,7 +50,18 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client" mgmt "github.com/netbirdio/netbird/shared/management/client"
@@ -39,9 +69,25 @@ import (
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
type MockWGIface struct { type MockWGIface struct {
CreateFunc func() error CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
@@ -178,10 +224,6 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil return nil
} }
func (m *MockWGIface) MTU() uint16 {
return 1280
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error { func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil return nil
} }
@@ -192,6 +234,129 @@ func TestMain(m *testing.M) {
os.Exit(code) os.Exit(code)
} }
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) { func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config // Test that SSH server start/stop logic works based on config
engine := &Engine{ engine := &Engine{
@@ -261,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(CtxInitState(context.Background())) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -437,7 +602,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} { for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true) err = engine.updateNetworkMap(c.networkMap)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -464,47 +629,97 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
} }
}) })
} }
}
// chunked apply: with a per-pass cap smaller than the number of peers, a func TestEngine_Sync(t *testing.T) {
// single updateNetworkMap applies one batch and reports more==true; the key, err := wgtypes.GeneratePrivateKey()
// caller re-runs until convergence. (engine currently holds 0 peers.) if err != nil {
t.Run("chunked add converges over multiple passes", func(t *testing.T) { t.Fatal(err)
nm := &mgmtProto.NetworkMap{ return
Serial: 6, }
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
} }
more, err := engine.updateNetworkMap(nm, 1, true) if getPeers(engine) == 3 && engine.networkSerial == 10 {
require.NoError(t, err) break
require.True(t, more, "pass 1 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.True(t, more, "pass 2 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 2)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 3 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 3)
})
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 7,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
} }
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 2 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
})
} }
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
@@ -602,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(CtxInitState(context.Background())) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -675,7 +890,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
} }
}() }()
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true) err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match") assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
@@ -809,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(CtxInitState(context.Background())) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -879,7 +1094,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
} }
}() }()
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true) err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match") assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
@@ -890,6 +1105,104 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
} }
} }
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers was started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
func Test_ParseNATExternalIPMappings(t *testing.T) { func Test_ParseNATExternalIPMappings(t *testing.T) {
ifaceList, err := net.Interfaces() ifaceList, err := net.Interfaces()
if err != nil { if err != nil {
@@ -1213,6 +1526,187 @@ func TestCompareNetIPLists(t *testing.T) {
} }
} }
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte { func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper() t.Helper()
b, err := netiputil.EncodePrefix(p) b, err := netiputil.EncodePrefix(p)

View File

@@ -44,5 +44,4 @@ type wgIfaceBase interface {
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
MTU() uint16
} }

View File

@@ -119,16 +119,15 @@ func (d *BindListener) ReadPackets() {
} }
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close() _ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP) d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done() d.done.Done()
} }
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
func (d *BindListener) CapturedPacket() []byte {
return nil
}
// Close stops the listener and cleans up resources. // Close stops the listener and cleans up resources.
func (d *BindListener) Close() { func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)") d.peerCfg.Log.Infof("closing activity listener (LazyConn)")

View File

@@ -45,6 +45,10 @@ type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager endpointMgr *mockEndpointManager
} }
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil return nil
} }
@@ -64,10 +68,6 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr return m.endpointMgr
} }
func (m *MockWGIfaceBind) MTU() uint16 {
return 1280
}
func TestBindListener_Creation(t *testing.T) { func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager() mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -207,9 +207,8 @@ func TestManager_BindMode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
select { select {
case ev := <-mgr.OnActivityChan: case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match") assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification") t.Fatal("timeout waiting for activity notification")
} }
@@ -267,8 +266,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
receivedPeers := make(map[peerid.ConnID]bool) receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
select { select {
case ev := <-mgr.OnActivityChan: case peerConnID := <-mgr.OnActivityChan:
receivedPeers[ev.PeerConnID] = true receivedPeers[peerConnID] = true
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications") t.Fatal("timeout waiting for activity notifications")
} }

View File

@@ -3,13 +3,11 @@ package activity
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
) )
@@ -22,8 +20,6 @@ type UDPListener struct {
done sync.Mutex done sync.Mutex
isClosed atomic.Bool isClosed atomic.Bool
capturedPacket []byte
} }
// NewUDPListener creates a listener that detects activity via UDP socket reads. // NewUDPListener creates a listener that detects activity via UDP socket reads.
@@ -50,13 +46,9 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
} }
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. // ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
// The first packet that triggers activity is captured so it can be reinjected through the real
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
// dropped and WG would only retry after REKEY_TIMEOUT.
func (d *UDPListener) ReadPackets() { func (d *UDPListener) ReadPackets() {
for { for {
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead) n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
if err != nil { if err != nil {
if d.isClosed.Load() { if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener") d.peerCfg.Log.Infof("exit from activity listener")
@@ -70,24 +62,20 @@ func (d *UDPListener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr) d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue continue
} }
d.capturedPacket = slices.Clone(buf[:n]) d.peerCfg.Log.Infof("activity detected")
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
break break
} }
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint; d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
// removing the peer here wipes kernel WG's staged queue and drops the user packet that if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
// triggered activation. d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Ignore close error as it may return "use of closed network connection" if already closed.
_ = d.conn.Close() _ = d.conn.Close()
d.done.Unlock() d.done.Unlock()
} }
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns.
func (d *UDPListener) CapturedPacket() []byte {
return d.capturedPacket
}
// Close stops the listener and cleans up resources. // Close stops the listener and cleans up resources.
func (d *UDPListener) Close() { func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())

View File

@@ -19,25 +19,17 @@ import (
type listener interface { type listener interface {
ReadPackets() ReadPackets()
Close() Close()
CapturedPacket() []byte
}
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
// captured for reinjection through the real transport.
type Event struct {
PeerConnID peerid.ConnID
FirstPacket []byte
} }
type WgInterface interface { type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool IsUserspaceBind() bool
Address() wgaddr.Address Address() wgaddr.Address
MTU() uint16
} }
type Manager struct { type Manager struct {
OnActivityChan chan Event OnActivityChan chan peerid.ConnID
wgIface WgInterface wgIface WgInterface
@@ -49,7 +41,7 @@ type Manager struct {
func NewManager(wgIface WgInterface) *Manager { func NewManager(wgIface WgInterface) *Manager {
m := &Manager{ m := &Manager{
OnActivityChan: make(chan Event, 1), OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface, wgIface: wgIface,
peers: make(map[peerid.ConnID]listener), peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}), done: make(chan struct{}),
@@ -124,12 +116,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
delete(m.peers, peerConnID) delete(m.peers, peerConnID)
m.mu.Unlock() m.mu.Unlock()
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()}) m.notify(peerConnID)
} }
func (m *Manager) notify(ev Event) { func (m *Manager) notify(peerConnID peerid.ConnID) {
select { select {
case <-m.done: case <-m.done:
case m.OnActivityChan <- ev: case m.OnActivityChan <- peerConnID:
} }
} }

View File

@@ -1,7 +1,6 @@
package activity package activity
import ( import (
"bytes"
"net" "net"
"net/netip" "net/netip"
"testing" "testing"
@@ -26,6 +25,10 @@ func (m *MocPeer) ConnID() peerid.ConnID {
type MocWGIface struct { type MocWGIface struct {
} }
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil return nil
} }
@@ -41,10 +44,6 @@ func (m MocWGIface) Address() wgaddr.Address {
} }
} }
func (m MocWGIface) MTU() uint16 {
return 1280
}
// GetPeerListener is a test helper to access listeners // GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock() m.mu.Lock()
@@ -87,15 +86,11 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
} }
select { select {
case ev := <-mgr.OnActivityChan: case peerConnID := <-mgr.OnActivityChan:
if ev.PeerConnID != peerCfg1.PeerConnID { if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID) t.Fatalf("unexpected peerConnID: %v", peerConnID)
}
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
} }
} }

View File

@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case ev := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ev) m.onPeerActivity(peerConnID)
case peerIDs := <-m.inactivityManager.InactivePeersChan(): case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs) m.onPeerInactivityTimedOut(peerIDs)
} }
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
return false return false
} }
func (m *Manager) onPeerActivity(ev activity.Event) { func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[ev.PeerConnID] mp, ok := m.managedPeersByConnID[peerConnID]
if !ok { if !ok {
log.Errorf("peer not found by conn id: %v", ev.PeerConnID) log.Errorf("peer not found by conn id: %v", peerConnID)
return return
} }
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(ev activity.Event) {
m.activateHAGroupPeers(mp.peerCfg) m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) { func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {

View File

@@ -17,5 +17,4 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
Address() wgaddr.Address Address() wgaddr.Address
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
MTU() uint16
} }

View File

@@ -1,206 +0,0 @@
package internal
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// mapStateManager is the single read/write point between the management stream
// (writes) and the convergence loop (reads/applies).
//
// The stream calls SetTarget with the latest full SyncResponse — the complete
// desired state. A single background goroutine (run) applies it to the engine in
// bounded passes via apply() until converged, releasing syncMsgMux between passes
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
// coalesces: it keeps converging toward the latest target and the intermediate one
// is SKIPPED — never applied on its own (logged, no onConverged).
//
// Convergence is a single comparison: appliedGen == targetGen. targetGen
// increments on every SetTarget (an internal generation counter, so it also covers
// config-only updates that carry no network-map serial).
//
// onConverged fires once for each — and only each — map that is actually processed
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
// to a real, completed alignment.
type mapStateManager struct {
// apply performs one bounded apply pass and reports whether more passes are needed.
// firstPass is true on the first pass of a given target, so the caller can run
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
// re-runs that only drain the bounded peer batches. The manager owns this signal
// because it owns the convergence boundary; the engine need not track serials for it.
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
// onConverged is called once per processed map, with the elapsed time since that
// map was received (for the sync-duration metric / "sync finished" log).
onConverged func(time.Duration)
// persist snapshots an update to disk for restore-on-restart. Called once per
// update received from management (in SetTarget), including ones later coalesced
// or skipped from apply, so the on-disk state mirrors what management last sent.
// The impl skips config-only updates (nil NetworkMap). May be nil.
persist func(*mgmProto.SyncResponse)
mu sync.Mutex
target *mgmProto.SyncResponse
targetGen uint64
appliedGen uint64
targetSetAt time.Time
wake chan struct{}
}
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
return &mapStateManager{
apply: apply,
persist: persist,
onConverged: onConverged,
wake: make(chan struct{}, 1),
}
}
// SetTarget records the latest update as the desired state and wakes the loop.
// It returns immediately; convergence happens in the background. Serial-based
// staleness of the network map is still enforced inside apply (updateNetworkMap).
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
m.mu.Lock()
// A target that has not settled yet (targetGen > appliedGen) is being superseded
// before it converged: we coalesce to the latest map and never apply this one on
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
if m.target != nil && m.targetGen > m.appliedGen {
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
}
m.target = m.mergeTarget(m.target, update)
// Bump an internal generation counter, NOT the map serial: config-only updates
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
// serial, yet must still be applied. Every SetTarget is therefore a distinct
// target regardless of payload. Map-serial staleness is enforced separately
// inside apply (updateNetworkMap).
m.targetGen++
m.targetSetAt = time.Now()
m.mu.Unlock()
select {
case m.wake <- struct{}{}:
default:
}
// Persist every update received from management — once per update (not per apply
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
// state always reflects the latest map management sent. Done after waking the loop
// so convergence can start in parallel with the disk write. The persist impl skips
// config-only updates (nil NetworkMap).
if m.persist != nil {
m.persist(update)
}
return nil
}
// mergeTarget combines the currently pending target with a freshly received update
// and returns the new desired state. It is called under m.mu from SetTarget and is
// the single seam where the replace-vs-squash decision lives.
//
// Today management always sends a FULL map (the complete desired state), so the
// update simply replaces whatever was pending — prev is ignored. When management
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
// rest of the manager (generation tracking, convergence, signaling) is unaffected
// because it already treats target as "the complete desired state, whatever it is".
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
// Plain replace unless a config-only update (no map) arrives while the pending
// target still has an unconverged map (targetGen > appliedGen). In that case graft
// the pending map (and its checks) onto the incoming config update, so the next pass
// keeps converging the map instead of settling on a nil-map target and stranding the
// remaining peer batches until another full map arrives.
if update == nil || update.GetNetworkMap() != nil || prev == nil || prev.GetNetworkMap() == nil || m.targetGen == m.appliedGen {
return update
}
merged, ok := proto.Clone(update).(*mgmProto.SyncResponse)
if !ok {
return update
}
merged.NetworkMap = prev.GetNetworkMap()
merged.Checks = prev.Checks
return merged
}
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
func (m *mapStateManager) run(ctx context.Context) {
// passGen is the generation of the most recent apply() call (0 = none). A pass is
// the first for its target when its generation differs from the previous one —
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
var passGen uint64
for {
m.mu.Lock()
target, tg, ag := m.target, m.targetGen, m.appliedGen
m.mu.Unlock()
// Fully converged (or nothing yet): block until a new target arrives.
if target == nil || ag == tg {
select {
case <-ctx.Done():
return
case <-m.wake:
continue
}
}
firstPass := tg != passGen
passGen = tg
more, err := m.apply(target, firstPass)
if err != nil {
if ctx.Err() != nil {
return
}
// Log and DROP this target — do not retry it. A deterministic failure
// (e.g. a malformed peer in the map) would otherwise spin every pass
// making no progress. Management is the source of truth and re-delivers
// the full map on the next sync, so dropping is safe; peers already
// applied this convergence stay (idempotent diffs) and the remainder is
// reconciled by the next target. Mirrors the legacy handleSync path,
// where the apply error was logged by the gRPC client and the update
// dropped. No onConverged: this target did not converge.
log.Errorf("apply sync pass, dropping update: %v", err)
m.settle(tg, false)
continue
}
if more {
// keep converging the current target; syncMsgMux was released by apply
// between passes so other subsystems interleave.
continue
}
// This pass converged. Mark applied and signal this one map.
m.settle(tg, true)
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
// ag<tg next iteration -> apply it; this generation was skipped (logged in
// SetTarget) and is not signaled.
}
}
// settle marks generation tg as processed so the loop goes idle instead of
// re-applying the same target. It is a no-op when a newer target arrived during the
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
// just-finished generation was already counted as skipped.
//
// When signal is true (the pass converged) it fires onConverged once for this map;
// when false (the target was dropped on error) it does not — the map did not converge.
func (m *mapStateManager) settle(tg uint64, signal bool) {
m.mu.Lock()
if m.targetGen != tg {
m.mu.Unlock()
return
}
m.appliedGen = tg
setAt := m.targetSetAt
m.mu.Unlock()
if signal && m.onConverged != nil {
m.onConverged(time.Since(setAt))
}
}

View File

@@ -1,270 +0,0 @@
package internal
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// a config-only update arriving while a full map is still converging must keep the
// pending map (so its remaining peer batches still apply); once converged or when the
// pending target has no map, it replaces as usual.
func TestMapStateManager_MergeTargetPreservesPendingMap(t *testing.T) {
m := newMapStateManager(nil, nil, nil)
fullMap := &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
configOnly := &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
// still converging the full map (targetGen > appliedGen): graft the map onto the
// incoming config-only update instead of dropping it
m.targetGen, m.appliedGen = 5, 4
merged := m.mergeTarget(fullMap, configOnly)
require.NotNil(t, merged.GetNetworkMap(), "pending map must be preserved")
require.EqualValues(t, 5, merged.GetNetworkMap().GetSerial())
require.NotNil(t, merged.GetNetbirdConfig(), "new config must be carried")
require.NotSame(t, configOnly, merged, "must not mutate the received update in place")
// already converged (targetGen == appliedGen): nothing pending -> plain replace
m.targetGen, m.appliedGen = 5, 5
require.Same(t, configOnly, m.mergeTarget(fullMap, configOnly))
// a full map always replaces
newFull := &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 6}}
m.targetGen, m.appliedGen = 5, 4
require.Same(t, newFull, m.mergeTarget(fullMap, newFull))
}
// converges over the bounded passes (apply returns more until the 3rd pass),
// fires onConverged exactly once, then blocks (no further apply) until a new target.
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
var passes int32
var firstPasses int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
if firstPass {
atomic.AddInt32(&firstPasses, 1)
}
return n < 3, nil // more on pass 1 and 2, converge on pass 3
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("manager did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
// once converged the loop blocks: no further apply calls
time.Sleep(100 * time.Millisecond)
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
}
// persist runs once per received update (not per apply pass), regardless of how many
// bounded passes that target takes to converge.
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
var passes, persists int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
return n < 3, nil // 3 passes for one target
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
}
// every update received from management is persisted — even one that is coalesced /
// skipped from apply before it ever converges.
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
release := make(chan struct{})
var persists int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
<-release // hold the first apply so the second update coalesces/skips
return false, nil
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
close(release)
// both updates persisted even though map1 is skipped from apply
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
}
// each map that is actually processed (converged before the next arrives) fires
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
converged := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
const maps = 3
for i := 0; i < maps; i++ {
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select { // wait for this map to converge before sending the next (no coalescing)
case <-converged:
case <-time.After(2 * time.Second):
t.Fatalf("map %d not signaled", i)
}
}
// no extra signals once the stream goes quiet
select {
case <-converged:
t.Fatal("unexpected extra onConverged")
case <-time.After(100 * time.Millisecond):
}
}
// a map superseded before it converges is skipped: only the latest (processed) map
// fires onConverged, not the skipped one.
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
release := make(chan struct{})
var applies, converged atomic.Int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applies.Add(1)
<-release // hold the first apply in-flight so we can queue a newer target
return false, nil
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// map1 is picked up; its apply blocks on release
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
// map2 supersedes map1 before it settled -> map1 is skipped
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
close(release) // let both applies proceed
// only the processed (latest) map signals; the skipped one does not
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
time.Sleep(150 * time.Millisecond)
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
}
// an apply error drops the target: no retry of the same target, no onConverged,
// the loop goes idle — and a fresh target is still applied afterwards.
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
applied := make(chan struct{}, 8)
var failNext atomic.Bool
failNext.Store(true)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
if failNext.Load() {
return false, errors.New("boom")
}
return false, nil // converge in one pass
}
var converged atomic.Int32
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// first target errors -> applied once, then dropped (no retry, no onConverged)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("errored target not applied")
}
select {
case <-applied:
t.Fatal("errored target must not be retried")
case <-time.After(150 * time.Millisecond):
}
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
// a new target is still processed normally and converges
failNext.Store(false)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target after error not applied")
}
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
}
// a new target after convergence triggers a fresh apply; an idle (converged)
// manager does not apply on its own.
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
applied := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("first target not applied")
}
// converged → must stay idle (no spurious apply)
select {
case <-applied:
t.Fatal("unexpected apply while idle/converged")
case <-time.After(150 * time.Millisecond):
}
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target not applied")
}
}

View File

@@ -120,30 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked() m.trimLocked()
} }
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) { func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success" result := "success"
if !success { if !success {

View File

@@ -78,25 +78,6 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.) - `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.) - `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration ### Login Duration
Measurement: `netbird_login` Measurement: `netbird_login`
@@ -211,51 +192,3 @@ docker compose exec influxdb influx query \
# Check ingest server health # Check ingest server health
curl http://localhost:8087/health curl http://localhost:8087/health
``` ```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.

View File

@@ -1,259 +0,0 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087" defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns" defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this maxTagValueLength = 64 // reject tag values longer than this
) )
@@ -59,19 +59,6 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true, "peer_id": true,
}, },
}, },
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": { "netbird_login": {
allowedFields: map[string]bool{ allowedFields: map[string]bool{
"duration_seconds": true, "duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
} }
func TestValidateLine_DurationTooLarge(t *testing.T) { func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890` line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
err := validateLine(line) err := validateLine(line)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "too large") assert.Contains(t, err.Error(), "too large")
} }
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) { func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890` line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
err := validateLine(line) err := validateLine(line)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "too large") assert.Contains(t, err.Error(), "too large")

View File

@@ -56,9 +56,6 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message // RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration) RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took // RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool) RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
@@ -130,18 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration) c.impl.RecordSyncDuration(ctx, agentInfo, duration)
} }
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took // RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) { func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil { if c == nil {

View File

@@ -70,9 +70,6 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) { func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
} }
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) { func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
} }

View File

@@ -6,7 +6,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"slices"
"sync" "sync"
"time" "time"
@@ -137,39 +136,6 @@ type Conn struct {
// Connection stage timestamps for metrics // Connection stage timestamps for metrics
metricsRecorder MetricsRecorder metricsRecorder MetricsRecorder
metricsStages *MetricsStages metricsStages *MetricsStages
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
// transport is up.
pendingFirstPacket []byte
}
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
pkt := conn.pendingFirstPacket
if len(pkt) == 0 {
return
}
switch {
case proxy != nil:
if err := proxy.InjectPacket(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
return
}
case directConn != nil:
if _, err := directConn.Write(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
return
}
default:
conn.Log.Debugf("no transport available to reinject captured first packet")
return
}
conn.pendingFirstPacket = nil
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@@ -206,16 +172,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open(engineCtx context.Context) error { func (conn *Conn) Open(engineCtx context.Context) error {
return conn.open(engineCtx, nil)
}
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
// the real transport is established. The packet is retained only on a successful open.
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
return conn.open(engineCtx, firstPacket)
}
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -271,9 +227,6 @@ func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
defer conn.wg.Done() defer conn.wg.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent) conn.guard.Start(conn.ctx, conn.onGuardEvent)
}() }()
if len(firstPacket) > 0 {
conn.pendingFirstPacket = slices.Clone(firstPacket)
}
conn.opened = true conn.opened = true
return nil return nil
} }
@@ -470,8 +423,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep) conn.wgProxyRelay.RedirectAs(ep)
} }
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo, updateTime) conn.updateIceState(iceConnInfo, updateTime)
@@ -595,8 +546,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.injectPendingFirstPacket(wgProxy, nil)
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected() conn.statusRelay.SetConnected()

View File

@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
} }
offer := h.buildOfferAnswer() offer := h.buildOfferAnswer()
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString()) h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
return h.signaler.SignalOffer(offer, h.config.Key) return h.signaler.SignalOffer(offer, h.config.Key)
} }
func (h *Handshaker) sendAnswer() error { func (h *Handshaker) sendAnswer() error {
answer := h.buildOfferAnswer() answer := h.buildOfferAnswer()
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString()) h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key) return h.signaler.SignalAnswer(answer, h.config.Key)
} }

View File

@@ -192,7 +192,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
// Pure read methods take RLock; anything that mutates state takes Lock. // Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct { type Status struct {
mux sync.RWMutex mux sync.RWMutex
muxRelays sync.RWMutex
peers map[string]State peers map[string]State
ipToKey map[string]string ipToKey map[string]string
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
@@ -245,8 +244,8 @@ func NewRecorder(mgmAddress string) *Status {
} }
func (d *Status) SetRelayMgr(manager *relayClient.Manager) { func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.muxRelays.Lock() d.mux.Lock()
defer d.muxRelays.Unlock() defer d.mux.Unlock()
d.relayMgr = manager d.relayMgr = manager
} }
@@ -907,8 +906,8 @@ func (d *Status) MarkSignalConnected() {
} }
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.muxRelays.Lock() d.mux.Lock()
defer d.muxRelays.Unlock() defer d.mux.Unlock()
d.relayStates = relayResults d.relayStates = relayResults
} }
@@ -1019,26 +1018,21 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states // GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetRelayStates() []relay.ProbeResult {
d.muxRelays.RLock() d.mux.RLock()
defer d.mux.RUnlock()
if d.relayMgr == nil { if d.relayMgr == nil {
defer d.muxRelays.RUnlock() return d.relayStates
return slices.Clone(d.relayStates)
} }
relayMgr := d.relayMgr // extend the list of stun, turn servers with relay address
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates) relayStates := slices.Clone(d.relayStates)
d.muxRelays.RUnlock()
states := relayMgr.RelayStates() // if the server connection is not established then we will use the general address
if len(states) == 0 { // in case of connection we will use the instance specific address
// no relay connection tracked yet; surface configured servers as instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
// unavailable with the real reconnect error when known if err != nil {
err := relayClient.ErrRelayClientNotConnected // TODO add their status
if connErr := relayMgr.RelayConnectError(); connErr != nil { for _, r := range d.relayMgr.ServerURLs() {
err = connErr
}
for _, r := range relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{ relayStates = append(relayStates, relay.ProbeResult{
URI: r, URI: r,
Err: err, Err: err,
@@ -1047,14 +1041,10 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates return relayStates
} }
for _, rs := range states { relayState := relay.ProbeResult{
relayStates = append(relayStates, relay.ProbeResult{ URI: instanceAddr,
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
} }
return relayStates return append(relayStates, relayState)
} }
func (d *Status) ForwardingRules() []firewall.ForwardRule { func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1415,7 +1405,6 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{ pbRelayState := &proto.RelayState{
URI: relayState.URI, URI: relayState.URI,
Available: relayState.Err == nil, Available: relayState.Err == nil,
Transport: relayState.Transport,
} }
if err := relayState.Err; err != nil { if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error() pbRelayState.Error = err.Error()

View File

@@ -88,24 +88,11 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
if !ok { if !ok {
return return
} }
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil { if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err) p.Log.Errorf("failed to open peer connection: %v", err)
} }
}
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
// reinjected once the real transport is established.
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
} }
func (s *Store) PeerConnIdle(pubKey string) { func (s *Store) PeerConnIdle(pubKey string) {

View File

@@ -108,10 +108,6 @@ type ConfigInput struct {
// Config Configuration type // Config Configuration type
type Config struct { type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer // Wireguard private key of local peer
PrivateKey string PrivateKey string
PreSharedKey string PreSharedKey string
@@ -274,16 +270,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
} }
func (config *Config) apply(input ConfigInput) (updated bool, err error) { func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.Name != "" {
sanitized, err := sanitizeDisplayName(config.Name)
if err != nil {
return false, fmt.Errorf("invalid profile name: %w", err)
}
if sanitized != config.Name {
config.Name = sanitized
updated = true
}
}
if config.ManagementURL == nil { if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL) log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
@@ -433,7 +419,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) { if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed { if *input.ServerSSHAllowed {
log.Infof("enabling SSH server") log.Infof("enabling SSH server")
} else { } else {

View File

@@ -242,35 +242,6 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
} }
} }
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
// Configs written before ServerSSHAllowed was introduced lack the field and
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
// apply the value instead of panicking on a nil pointer dereference.
tests := []struct {
name string
input *bool
want bool
}{
{"enable", util.True(), true},
{"disable", util.False(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
config, err := UpdateConfig(ConfigInput{
ConfigPath: configPath,
ServerSSHAllowed: tt.input,
})
require.NoError(t, err)
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) { func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) { newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {

View File

@@ -1,118 +0,0 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
type ID string
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (ID, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return ID(hex.EncodeToString(buf)), nil
}
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(id ID) bool {
s := id.String()
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func (id ID) ShortID() string {
if id == DefaultProfileName {
return DefaultProfileName
}
runes := []rune(id)
if len(runes) <= shortIDLen {
return id.String()
}
return string(runes[:shortIDLen])
}
func (id ID) String() string {
return string(id)
}

View File

@@ -19,41 +19,19 @@ const (
) )
type Profile struct { type Profile struct {
// ID is the on-disk filename stem (without .json). For new profiles Name string
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID ID
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
IsActive bool IsActive bool
} }
func (p *Profile) FilePath() (string, error) { func (p *Profile) FilePath() (string, error) {
if p.Path != "" { if p.Name == "" {
return p.Path, nil return "", fmt.Errorf("active profile name is empty")
} }
id := p.ID if p.Name == defaultProfileName {
if id == "" {
id = ID(p.Name)
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
return DefaultConfigPath, nil return DefaultConfigPath, nil
} }
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current() username, err := user.Current()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err) return "", fmt.Errorf("failed to get current user: %w", err)
@@ -64,13 +42,10 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err) return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
} }
return filepath.Join(configDir, id.String()+".json"), nil return filepath.Join(configDir, p.Name+".json"), nil
} }
func (p *Profile) IsDefault() bool { func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName return p.Name == defaultProfileName
} }
@@ -82,24 +57,18 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{} return &ProfileManager{}
} }
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) { func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
id := pm.getActiveProfileState() prof := pm.getActiveProfileState()
return &Profile{ID: id}, nil return &Profile{Name: prof}, nil
} }
// SwitchProfile records the given profile ID as active in the local user func (pm *ProfileManager) SwitchProfile(profileName string) error {
// state file. profileName = sanitizeProfileName(profileName)
func (pm *ProfileManager) SwitchProfile(id ID) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
if err := pm.setActiveProfileState(id); err != nil { if err := pm.setActiveProfileState(profileName); err != nil {
return fmt.Errorf("failed to switch profile: %w", err) return fmt.Errorf("failed to switch profile: %w", err)
} }
return nil return nil
@@ -116,7 +85,7 @@ func sanitizeProfileName(name string) string {
}, name) }, name)
} }
func (pm *ProfileManager) getActiveProfileState() ID { func (pm *ProfileManager) getActiveProfileState() string {
configDir, err := getConfigDir() configDir, err := getConfigDir()
if err != nil { if err != nil {
@@ -144,10 +113,10 @@ func (pm *ProfileManager) getActiveProfileState() ID {
return defaultProfileName return defaultProfileName
} }
return ID(profileName) return profileName
} }
func (pm *ProfileManager) setActiveProfileState(id ID) error { func (pm *ProfileManager) setActiveProfileState(profileName string) error {
configDir, err := getConfigDir() configDir, err := getConfigDir()
if err != nil { if err != nil {
@@ -156,7 +125,7 @@ func (pm *ProfileManager) setActiveProfileState(id ID) error {
statePath := filepath.Join(configDir, activeProfileStateFilename) statePath := filepath.Join(configDir, activeProfileStateFilename)
err = os.WriteFile(statePath, []byte(id), 0600) err = os.WriteFile(statePath, []byte(profileName), 0600)
if err != nil { if err != nil {
return fmt.Errorf("failed to write active profile state: %w", err) return fmt.Errorf("failed to write active profile state: %w", err)
} }
@@ -173,7 +142,7 @@ func GetLoginHint() string {
return "" return ""
} }
profileState, err := pm.GetProfileState(activeProf.ID) profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil { if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err) log.Debugf("failed to get profile state for login hint: %v", err)
return "" return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState() state, err := sm.GetActiveProfileState()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
err = sm.SetActiveProfileStateToDefault() err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err) assert.NoError(t, err)
active, err := sm.GetActiveProfileState() active, err := sm.GetActiveProfileState()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "default", active.ID.String()) assert.Equal(t, "default", active.Name)
}) })
}) })
} }
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current() currUser, err := user.Current()
assert.NoError(t, err) assert.NoError(t, err)
sm := &ServiceManager{} sm := &ServiceManager{}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username} state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state) err = sm.SetActiveProfileState(state)
assert.NoError(t, err) assert.NoError(t, err)
// Should error on nil or incomplete state // Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil) err = sm.SetActiveProfileState(nil)
assert.Error(t, err) assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""}) err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
assert.Error(t, err) assert.Error(t, err)
}) })
}) })

View File

@@ -2,7 +2,6 @@ package profilemanager
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -24,43 +23,12 @@ var (
DefaultConfigPathDir = "" DefaultConfigPathDir = ""
DefaultConfigPath = "" DefaultConfigPath = ""
ActiveProfileStatePath = "" ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found") ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
) )
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() { func init() {
DefaultConfigPathDir = "/var/lib/netbird/" DefaultConfigPathDir = "/var/lib/netbird/"
@@ -86,34 +54,25 @@ func init() {
} }
type ActiveProfileState struct { type ActiveProfileState struct {
// ID is the on-disk filename stem of the active profile. The JSON tag stays Name string `json:"name"`
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID ID `json:"name"`
Username string `json:"username"` Username string `json:"username"`
} }
func (a *ActiveProfileState) FilePath() (string, error) { func (a *ActiveProfileState) FilePath() (string, error) {
if a.ID == "" { if a.Name == "" {
return "", fmt.Errorf("active profile ID is empty") return "", fmt.Errorf("active profile name is empty")
} }
if a.ID == defaultProfileName { if a.Name == defaultProfileName {
return DefaultConfigPath, nil return DefaultConfigPath, nil
} }
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username) configDir, err := getConfigDirForUser(a.Username)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err) return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
} }
return filepath.Join(configDir, a.ID.String()+".json"), nil return filepath.Join(configDir, a.Name+".json"), nil
} }
type ServiceManager struct { type ServiceManager struct {
@@ -219,7 +178,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err) return nil, fmt.Errorf("failed to set active profile to default: %w", err)
} }
return &ActiveProfileState{ return &ActiveProfileState{
ID: defaultProfileName, Name: "default",
Username: "", Username: "",
}, nil }, nil
} else { } else {
@@ -227,12 +186,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
} }
} }
if activeProfile.ID == "" { if activeProfile.Name == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil { if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err) return nil, fmt.Errorf("failed to set active profile to default: %w", err)
} }
return &ActiveProfileState{ return &ActiveProfileState{
ID: defaultProfileName, Name: "default",
Username: "", Username: "",
}, nil }, nil
} }
@@ -257,29 +216,25 @@ func (s *ServiceManager) setDefaultActiveState() error {
} }
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error { func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.ID == "" { if a == nil || a.Name == "" {
return errors.New("invalid active profile state") return errors.New("invalid active profile state")
} }
if a.ID != defaultProfileName && a.Username == "" { if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID) return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
} }
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil { if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err) return fmt.Errorf("failed to write active profile state: %w", err)
} }
log.Infof("active profile set to %s for %s", a.ID, a.Username) log.Infof("active profile set to %s for %s", a.Name, a.Username)
return nil return nil
} }
func (s *ServiceManager) SetActiveProfileStateToDefault() error { func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{ return s.SetActiveProfileState(&ActiveProfileState{
ID: defaultProfileName, Name: "default",
Username: "", Username: "",
}) })
} }
@@ -288,117 +243,57 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath return DefaultConfigPath
} }
// AddProfile creates a new profile with a generated ID. The user-supplied func (s *ServiceManager) AddProfile(profileName, username string) error {
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
configDir, err := s.getConfigDir(username) configDir, err := s.getConfigDir(username)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err) return fmt.Errorf("failed to get config directory: %w", err)
} }
displayName, err = sanitizeDisplayName(displayName) profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid profile name: %w", err) return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
} }
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id.String()+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath}) cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create new config: %w", err) return fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
} }
return &Profile{ err = util.WriteJson(context.Background(), profPath, cfg)
ID: id,
Name: displayName,
Path: profPath,
}, nil
}
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
displayName, err := sanitizeDisplayName(newName)
if err != nil { if err != nil {
return fmt.Errorf("invalid profile name: %w", err) return fmt.Errorf("failed to write profile config: %w", err)
} }
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
data, err := os.ReadFile(target.Path)
if err != nil {
return err
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
return fmt.Errorf("failed to write profile name: %w", err)
}
return nil return nil
} }
// RemoveProfile deletes the profile identified by id. Callers must have func (s *ServiceManager) RemoveProfile(profileName, username string) error {
// already resolved any user-supplied handle to a concrete ID via configDir, err := s.getConfigDir(username)
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if id == defaultProfileName {
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
if err != nil { if err != nil {
return fmt.Errorf("load profiles: %w", err) return fmt.Errorf("failed to get config directory: %w", err)
} }
var target *Profile profileName = sanitizeProfileName(profileName)
for i := range profiles {
if profiles[i].ID == id { if profileName == defaultProfileName {
target = &profiles[i] return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
break
}
} }
if target == nil { profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound return ErrProfileNotFound
} }
@@ -406,26 +301,57 @@ func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) { if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err) return fmt.Errorf("failed to get active profile: %w", err)
} }
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id) if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
} }
if err := util.RemoveJson(target.Path); err != nil { err = util.RemoveJson(profPath)
if err != nil {
return fmt.Errorf("failed to remove profile config: %w", err) return fmt.Errorf("failed to remove profile config: %w", err)
} }
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil return nil
} }
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) { func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
return s.loadAllProfiles(username) configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
} }
// GetStatePath returns the path to the state file based on the operating system // GetStatePath returns the path to the state file based on the operating system
@@ -443,12 +369,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath return defaultStatePath
} }
if activeProf.ID == defaultProfileName { if activeProf.Name == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
return defaultStatePath return defaultStatePath
} }
@@ -458,7 +379,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath return defaultStatePath
} }
return filepath.Join(configDir, activeProf.ID.String()+".state.json") return filepath.Join(configDir, activeProf.Name+".state.json")
} }
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser // getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -469,169 +390,3 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username) return getConfigDirForUser(username)
} }
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := ID(strings.TrimSuffix(base, ".json"))
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(ID(stem)) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem.String()
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == ID(activeID),
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (ID, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique exact name, then unique ID
// prefix. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == ID(handle) {
return &profiles[i], nil
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID.String(), handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -1,230 +0,0 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID.String())
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID.String())
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(ID(tc.in))
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,20 +13,13 @@ type ProfileState struct {
Email string `json:"email"` Email string `json:"email"`
} }
// GetProfileState reads the per-profile state file keyed by profile ID. func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
configDir, err := getConfigDir() configDir, err := getConfigDir()
if err != nil { if err != nil {
return nil, fmt.Errorf("get config directory: %w", err) return nil, fmt.Errorf("get config directory: %w", err)
} }
if id != defaultProfileName && !IsValidProfileFilenameStem(id) { stateFile := filepath.Join(configDir, profileName+".state.json")
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFileExists, err := fileExists(stateFile) stateFileExists, err := fileExists(stateFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err) return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -58,12 +51,7 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err) return fmt.Errorf("get active profile: %w", err)
} }
id := activeProf.ID stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state) err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil { if err != nil {
return fmt.Errorf("write profile state: %w", err) return fmt.Errorf("write profile state: %w", err)

View File

@@ -32,9 +32,6 @@ type ProbeResult struct {
URI string URI string
Err error Err error
Addr string Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
} }
type StunTurnProbe struct { type StunTurnProbe struct {

View File

@@ -226,11 +226,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
// All query types for an intercepted domain are forwarded to the peer's // pass if non A/AAAA query
// DNS forwarder, which owns the name. Falling through to the system if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
// resolver would let it answer NXDOMAIN for a name it isn't authoritative d.continueToNextHandler(w, r, logger, "non A/AAAA query")
// for, poisoning the whole name (including the A/AAAA records the route return
// does serve). The forwarder answers NODATA for types it cannot resolve. }
d.mu.RLock() d.mu.RLock()
peerKey := d.currentPeerKey peerKey := d.currentPeerKey
d.mu.RUnlock() d.mu.RUnlock()
@@ -250,14 +251,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
// describing why a lookup failed. The OPT is stripped from the reply when
// the original client did not request EDNS0.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(dns.DefaultMsgSize, false)
}
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10)) upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()
@@ -267,13 +260,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if ede, ok := resutil.ExtractEDE(reply); ok {
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
}
if !hadEdns {
resutil.StripOPT(reply)
}
resutil.SetMeta(w, "peer", peerKey) resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id reply.Id = r.Id
@@ -292,6 +278,19 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
} }
} }
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) { func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists { if !exists {

View File

@@ -333,8 +333,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
} }
m.notifier.Close()
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.clientRoutes = nil m.clientRoutes = nil

View File

@@ -1,5 +1,3 @@
//go:build privileged
package routemanager package routemanager
import ( import (

View File

@@ -16,7 +16,7 @@ import (
type Notifier struct { type Notifier struct {
initialRoutes []*route.Route initialRoutes []*route.Route
currentRoutes []*route.Route currentRoutes []*route.Route
fakeIPRoutes []*route.Route fakeIPRoutes []*route.Route
listener listener.NetworkChangeListener listener listener.NetworkChangeListener
listenerMux sync.Mutex listenerMux sync.Mutex
@@ -119,7 +119,3 @@ func (n *Notifier) GetInitialRouteRanges() []string {
sort.Strings(initialStrings) sort.Strings(initialStrings)
return initialStrings return initialStrings
} }
func (n *Notifier) Close() {
// unused
}

View File

@@ -3,7 +3,6 @@
package notifier package notifier
import ( import (
"container/list"
"net/netip" "net/netip"
"slices" "slices"
"sort" "sort"
@@ -15,26 +14,19 @@ import (
) )
type Notifier struct { type Notifier struct {
mu sync.Mutex
cond *sync.Cond
currentPrefixes []string currentPrefixes []string
listener listener.NetworkChangeListener
queue *list.List listener listener.NetworkChangeListener
closed bool listenerMux sync.Mutex
} }
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
n := &Notifier{ return &Notifier{}
queue: list.New(),
}
n.cond = sync.NewCond(&n.mu)
go n.deliverLoop()
return n
} }
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.mu.Lock() n.listenerMux.Lock()
defer n.mu.Unlock() defer n.listenerMux.Unlock()
n.listener = listener n.listener = listener
} }
@@ -51,52 +43,32 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
} }
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0, len(prefixes)) newNets := make([]string, 0)
for _, prefix := range prefixes { for _, prefix := range prefixes {
newNets = append(newNets, prefix.String()) newNets = append(newNets, prefix.String())
} }
sort.Strings(newNets) sort.Strings(newNets)
n.mu.Lock()
if slices.Equal(n.currentPrefixes, newNets) { if slices.Equal(n.currentPrefixes, newNets) {
n.mu.Unlock()
return return
} }
n.currentPrefixes = newNets
routes := strings.Join(n.currentPrefixes, ",")
n.queue.PushBack(routes)
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) Close() { n.currentPrefixes = newNets
n.mu.Lock() n.notify()
n.closed = true }
n.cond.Signal() func (n *Notifier) notify() {
n.mu.Unlock() n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
}(n.listener)
} }
func (n *Notifier) GetInitialRouteRanges() []string { func (n *Notifier) GetInitialRouteRanges() []string {
return nil return nil
} }
func (n *Notifier) deliverLoop() {
for {
n.mu.Lock()
for n.queue.Len() == 0 && !n.closed {
n.cond.Wait()
}
if n.closed && n.queue.Len() == 0 {
n.mu.Unlock()
return
}
routes := n.queue.Remove(n.queue.Front()).(string)
l := n.listener
n.mu.Unlock()
if l != nil {
l.OnNetworkChanged(routes)
}
}
}

View File

@@ -38,7 +38,3 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
func (n *Notifier) GetInitialRouteRanges() []string { func (n *Notifier) GetInitialRouteRanges() []string {
return []string{} return []string{}
} }
func (n *Notifier) Close() {
// unused
}

View File

@@ -1,69 +0,0 @@
//go:build linux && !android
package systemops
import (
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}

View File

@@ -1,191 +0,0 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

Some files were not shown because too many files have changed in this diff Show More