Compare commits

1 Commits

Author SHA1 Message Date
mlsmaycon
5eb28acb11 [management] Account-scoped ephemeral peer cleanup
Replace the per-peer linked list with a per-account map keyed by
accountID. Each entry holds only the latest disconnect timestamp we
have observed for that account and a single timer that fires the next
sweep. Sweeps query the database for the authoritative stale set,
batch the deletes through peers.Manager.DeletePeers, then drop the
account from the tracker when lastDisc + lifeTime <= now (else
re-arm at horizon + cleanupWindow).

The drop rule is the entire termination story: an account stays
tracked only while OnPeerDisconnected keeps refreshing the
timestamp. There is no internal feedback loop that can advance
lastDisc on its own, so once disconnects stop the account drops in
at most one sweep.
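
The drop rule above can be sketched as a small pure function. This is
only an illustration: the name afterSweep, the fixed demoNow clock and
the two durations are assumptions, and "horizon" is read here as
lastDisc + lifeTime, which the message does not spell out.

```go
package main

import (
	"fmt"
	"time"
)

// Illustrative durations only; the real values are not stated in the commit.
const (
	lifeTime      = 10 * time.Minute
	cleanupWindow = 1 * time.Minute
)

// demoNow is a fixed clock reading so the example is deterministic.
var demoNow = time.Date(2026, 5, 19, 9, 0, 0, 0, time.UTC)

// afterSweep applies the drop rule once a sweep has finished at time now.
// It reports drop=true when lastDisc + lifeTime <= now. Otherwise it returns
// the time of the next sweep, horizon + cleanupWindow, where horizon is
// assumed to mean lastDisc + lifeTime.
func afterSweep(lastDisc, now time.Time) (next time.Time, drop bool) {
	horizon := lastDisc.Add(lifeTime)
	if !horizon.After(now) {
		return time.Time{}, true
	}
	return horizon.Add(cleanupWindow), false
}

func main() {
	// Last disconnect 15 minutes ago: the lifetime has elapsed, so the
	// account leaves the tracker.
	_, drop := afterSweep(demoNow.Add(-15*time.Minute), demoNow)
	fmt.Println("quiet account dropped:", drop)

	// Last disconnect 2 minutes ago: the account stays and is re-armed.
	next, drop := afterSweep(demoNow.Add(-2*time.Minute), demoNow)
	fmt.Println("busy account dropped:", drop, "next sweep in", next.Sub(demoNow))
}
```

Because the function only reads lastDisc, nothing inside a sweep can
push the horizon forward; only a fresh disconnect can.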

A timestamp beats the ref-counter alternative because the counter
drifts positive in three real situations the cleanup loop has no
signal for: peers deleted via the API while offline, peers that
reconnect within the lifetime window, and management restarts. The
timestamp design never claims to know the size of the stale set —
it only knows the latest disconnect we observed and uses that to
bound when it is safe to drop the account.

OnPeerConnected becomes a no-op. The sweep query already filters
reconnected peers at the database level (peer_status_connected =
false in the WHERE clause), so there is nothing the in-memory
tracker needs to do on reconnect. The interface method is preserved
for call-site compatibility.

LoadInitialPeers no longer runs the catch-up query synchronously.
It schedules a deferred load via time.AfterFunc at a random delay
between 8 and 10 minutes. Without the jitter, every management
replica in a fleet-wide deploy would issue the catch-up query
simultaneously. The catch-up itself is one GROUP BY against the
peers table:
```sql
  SELECT account_id, MAX(peer_status_last_seen)
  FROM peers
  WHERE ephemeral = true AND peer_status_connected = false
  GROUP BY account_id
```
For each row the tracker seeds an entry and arms a sweep at
max(now, last_seen + lifeTime) + cleanupWindow. Accounts whose
backlog is already stale are therefore swept cleanupWindow after the
catch-up runs, and accounts that disconnected recently wait out the
rest of their lifeTime plus cleanupWindow.
OnPeerDisconnected calls that arrive during the delay window seed
the tracker live, and the catch-up query skips accounts that are
already tracked.
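
A minimal sketch of the two timing calculations described above, the
jittered start delay and the seeded sweep time. The function names,
the durations for lifeTime and cleanupWindow, and the half-open
[8m, 10m) range are assumptions made for illustration; the real code
may differ.

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// Illustrative values; only the 8 to 10 minute range comes from the text above.
const (
	lifeTime        = 10 * time.Minute
	cleanupWindow   = 1 * time.Minute
	minInitialDelay = 8 * time.Minute
	maxInitialDelay = 10 * time.Minute
)

// demoNow is a fixed clock reading so the example is deterministic.
var demoNow = time.Date(2026, 5, 19, 9, 0, 0, 0, time.UTC)

// initialLoadDelay picks a random delay in [minInitialDelay, maxInitialDelay)
// so that replicas restarted together do not query at the same instant.
func initialLoadDelay() time.Duration {
	jitter := rand.Int63n(int64(maxInitialDelay - minInitialDelay))
	return minInitialDelay + time.Duration(jitter)
}

// seedSweepAt returns max(now, lastSeen + lifeTime) + cleanupWindow.
func seedSweepAt(lastSeen, now time.Time) time.Time {
	horizon := lastSeen.Add(lifeTime)
	if horizon.Before(now) {
		horizon = now
	}
	return horizon.Add(cleanupWindow)
}

func main() {
	fmt.Println("catch-up query runs in", initialLoadDelay())
	// Stale backlog: swept cleanupWindow after the catch-up runs.
	fmt.Println("stale backlog swept at", seedSweepAt(demoNow.Add(-time.Hour), demoNow))
	// Recent disconnect: waits out the rest of its lifetime first.
	fmt.Println("recent disconnect swept at", seedSweepAt(demoNow.Add(-time.Minute), demoNow))
}
```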

Stop() cancels both the deferred initial-load timer and every
per-account sweep timer, and flips a stopped flag so subsequent
OnPeerDisconnected calls are ignored. This makes restarts and test
teardown clean.
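
The shutdown behaviour can be sketched roughly as follows. The type
and field names are hypothetical, the timer callbacks are left empty,
and keeping the existing timer when a tracked account disconnects
again is an assumption; only the overall shape (cancel every timer,
set a stopped flag, ignore later disconnects) is taken from the
description above.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// demoDelay is long enough that no timer fires during the example.
const demoDelay = time.Hour

// accountEntry mirrors the described per-account state: the latest
// disconnect timestamp and a single sweep timer.
type accountEntry struct {
	lastDisc time.Time
	timer    *time.Timer
}

// tracker is a hypothetical stand-in for the ephemeral peer tracker.
type tracker struct {
	mu          sync.Mutex
	stopped     bool
	initialLoad *time.Timer
	accounts    map[string]*accountEntry
}

func newTracker(initialDelay time.Duration) *tracker {
	t := &tracker{accounts: make(map[string]*accountEntry)}
	// The deferred catch-up query would run in this callback.
	t.initialLoad = time.AfterFunc(initialDelay, func() {})
	return t
}

// onPeerDisconnected refreshes or seeds an account entry and reports
// whether the call was accepted. Calls after stop are ignored.
func (t *tracker) onPeerDisconnected(accountID string, sweepIn time.Duration) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.stopped {
		return false
	}
	now := time.Now()
	if e, ok := t.accounts[accountID]; ok {
		e.lastDisc = now // refresh only; the armed timer is kept
		return true
	}
	// The per-account sweep would run in this callback.
	timer := time.AfterFunc(sweepIn, func() {})
	t.accounts[accountID] = &accountEntry{lastDisc: now, timer: timer}
	return true
}

// stop cancels the deferred initial load and every per-account timer.
func (t *tracker) stop() {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.stopped = true
	t.initialLoad.Stop()
	for id, e := range t.accounts {
		e.timer.Stop()
		delete(t.accounts, id)
	}
}

// tracked returns the number of accounts currently on the cleanup list.
func (t *tracker) tracked() int {
	t.mu.Lock()
	defer t.mu.Unlock()
	return len(t.accounts)
}

func main() {
	tr := newTracker(demoDelay)
	fmt.Println("accepted before stop:", tr.onPeerDisconnected("account-a", demoDelay))
	tr.stop()
	fmt.Println("tracked after stop:", tr.tracked())
	fmt.Println("accepted after stop:", tr.onPeerDisconnected("account-b", demoDelay))
}
```

Holding one mutex across the flag check and the map update means a
disconnect racing with stop is either fully recorded and then
cancelled, or ignored outright.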

Two new store methods:
  GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
  GetEphemeralAccountsLastDisconnect(ctx)
Both are scoped, indexable queries that the existing peers table
supports without schema changes.

The pending metric is renamed from
management.ephemeral.peers.pending to
management.ephemeral.accounts.tracked to reflect the new semantics
(it now counts accounts on the cleanup list, not peers). Method
names on the metrics type are unchanged so no production call site
has to move. No new metric labels, no per-account cardinality.

The algorithm was validated against an in-memory SQLite peers
table through an 11-scenario prototype kept under proto/, including
pathological-churn and 4-hour randomized simulations. All scenarios
terminate; max observed per-account sweep rate stays bounded near
the lifeTime + cleanupWindow cadence even under sustained
disconnect churn.

Verification: go build, go vet, race-clean tests across the
ephemeral, store, and telemetry packages, plus a clean
golangci-lint pass on the touched packages.
2026-05-19 09:50:14 +02:00
217 changed files with 3775 additions and 14734 deletions

View File

@@ -1,45 +0,0 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
actions:
patterns:
- "*"
ignore:
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
- dependency-name: "git-town/action"
update-types:
- "version-update:semver-minor"
- "version-update:semver-major"
- package-ecosystem: "gomod"
directories:
- "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
aws-sdk:
patterns:
- "github.com/aws/aws-sdk-go-v2/*"
pion:
patterns:
- "github.com/pion/*"
gorm:
patterns:
- "gorm.io/*"
otel:
patterns:
- "go.opentelemetry.io/*"
testcontainers:
patterns:
- "github.com/testcontainers/testcontainers-go/*"
wireguard:
patterns:
- "golang.zx2c4.com/wireguard*"

View File

@@ -12,7 +12,6 @@
- [ ] Is a feature enhancement
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -2,16 +2,16 @@ name: Check License Dependencies
on:
push:
branches: [main]
branches: [ main ]
paths:
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request:
paths:
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs:
check-internal-dependencies:
@@ -19,10 +19,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
@@ -59,57 +56,55 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: true
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
done <<< "$COPYLEFT_DEPS"
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -83,7 +83,7 @@ jobs:
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}

View File

@@ -8,10 +8,11 @@ jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
- uses: roots/discourse-topic-github-release-action@main
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags: releases
discourse-tags:
releases

View File

@@ -3,7 +3,7 @@ name: Git Town
on:
pull_request:
branches:
- "**"
- '**'
jobs:
git-town:
@@ -15,9 +15,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
- uses: actions/checkout@v4
- uses: git-town/action@v1.2.1
with:
skip-single-stacks: true

View File

@@ -16,18 +16,16 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -46,3 +44,4 @@ jobs:
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -15,31 +15,20 @@ jobs:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to fail on first error

View File

@@ -18,11 +18,9 @@ jobs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
@@ -30,7 +28,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -38,10 +36,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
id: cache
with:
path: |
@@ -115,16 +113,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -132,10 +128,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -162,16 +158,14 @@ jobs:
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -183,7 +177,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
@@ -237,12 +231,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -254,10 +246,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -285,16 +277,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -308,7 +298,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -334,16 +324,14 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: ["386", "amd64"]
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -355,10 +343,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -382,21 +370,19 @@ jobs:
test_management:
name: "Management / Unit"
needs: [build-cache]
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres", "mysql"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -404,10 +390,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -424,7 +410,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -441,7 +427,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
@@ -451,13 +437,13 @@ jobs:
benchmark:
name: "Management / Benchmark"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -488,12 +474,10 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -501,10 +485,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -521,7 +505,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -545,13 +529,13 @@ jobs:
api_benchmark:
name: "Management / Benchmark (API)"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -582,12 +566,10 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -595,10 +577,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
@@ -615,7 +597,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -641,22 +623,20 @@ jobs:
api_integration_test:
name: "Management / Integration"
needs: [build-cache]
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: ["amd64"]
store: ["sqlite", "postgres"]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -664,10 +644,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}

View File

@@ -18,12 +18,10 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
id: go
with:
go-version-file: "go.mod"
@@ -35,7 +33,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
@@ -46,15 +44,16 @@ jobs:
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompressing wintun files
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'

View File

@@ -15,11 +15,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
@@ -40,15 +38,13 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
@@ -56,7 +52,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
skip-cache: true

View File

@@ -22,9 +22,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: run install script
env:

View File

@@ -16,25 +16,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -54,11 +52,9 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;

View File

@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
@@ -20,66 +20,34 @@ jobs:
per_page: 100,
});
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
async function getVersionHeader(path, ref) {
try {
const res = await github.rest.repos.getContent({
owner: context.repo.owner,
repo: context.repo.repo,
path,
ref,
});
if (!res.data.content) {
return { ok: false, reason: 'no inline content (file too large)' };
}
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
const lines = content
.split('\n')
.slice(0, 20)
.filter(line => versionPattern.test(line));
return { ok: true, lines };
} catch (e) {
return { ok: false, reason: e.message };
}
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
base: base.lines,
head: head.lines,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n` +
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.5"
SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -24,9 +24,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
@@ -53,26 +51,19 @@ jobs:
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
prepare: |
# Install required packages
pkg install -y git curl portlint
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
@@ -102,19 +93,19 @@ jobs:
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
@@ -133,25 +124,26 @@ jobs:
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -164,18 +156,18 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
uses: docker/setup-buildx-action@v2
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -199,7 +191,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
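
Review note: the `goversioninfo` step in this hunk consumes `major`, `minor`, `patch` and `fullversion` outputs from the semver parser step, which is fed `github.ref` or the fallback `refs/tags/v0.0.0` and the extractor regex `\/v(.*)$`. A minimal sketch of that extraction follows. It assumes a plain dot-split of the captured version; the real action's parsing may differ (for example around pre-release suffixes), and the function names are hypothetical.

```javascript
// Extractor regex as written in the workflow input.
const versionExtractor = /\/v(.*)$/;

// Mirrors the workflow expression: use the ref only if it is a v-tag,
// otherwise fall back to refs/tags/v0.0.0.
function refForParsing(ref) {
  return ref.startsWith('refs/tags/v') ? ref : 'refs/tags/v0.0.0';
}

// Hypothetical parser: an assumption about how the outputs are derived.
function parseSemverFromRef(ref) {
  const match = versionExtractor.exec(refForParsing(ref));
  if (!match) {
    throw new Error(`no version found in ref: ${ref}`);
  }
  const fullversion = match[1];
  // Drop any pre-release or build suffix before splitting into numbers.
  const core = fullversion.split(/[-+]/)[0];
  const [major, minor, patch] = core.split('.').map(Number);
  return { fullversion, major, minor, patch };
}

console.log(parseSemverFromRef('refs/tags/v1.2.3'));
console.log(parseSemverFromRef('refs/heads/main'));
```

With the fallback in place, non-tag builds resolve to version 0.0.0, which matches the snapshot behaviour the workflow selects through `flags=--snapshot`.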
@@ -290,28 +282,28 @@ jobs:
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -322,26 +314,27 @@ jobs:
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -382,7 +375,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -411,7 +404,7 @@ jobs:
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -425,17 +418,16 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -449,7 +441,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -457,7 +449,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
@@ -482,26 +474,27 @@ jobs:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
@@ -521,27 +514,29 @@ jobs:
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
destination: ${{ env.downloadPath }}\mesa3d.7z
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
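
Review note: both variants of the wintun and Mesa3D download steps in this hunk carry a `sha256` input, so the archive is expected to be rejected when its digest differs. The sketch below shows that check using Node's built-in `crypto` module. It is an illustration of the technique only, not the implementation of either download action, and the function names are hypothetical.

```javascript
const crypto = require('crypto');

// Hex-encoded SHA-256 digest of a buffer or string.
function sha256Hex(data) {
  return crypto.createHash('sha256').update(data).digest('hex');
}

// Hypothetical verifier: throws when the digest does not match the pinned value.
function verifySha256(data, expectedHex) {
  const actual = sha256Hex(data);
  if (actual !== expectedHex.toLowerCase()) {
    throw new Error(`sha256 mismatch: expected ${expectedHex}, got ${actual}`);
  }
  return actual;
}

// Demonstration with in-memory data instead of a downloaded archive.
const payload = Buffer.from('abc');
const pinned = 'ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad';
console.log(verifySha256(payload, pinned));
```

In a real step the buffer would come from the downloaded file, and the pinned value would be the digest recorded in the workflow.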
@@ -552,38 +547,35 @@ jobs:
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
uses: carlosperate/download-file-action@v2
with:
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
destination: ${{ github.workspace }}\envar_plugin.zip
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
shell: pwsh
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
run: |
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
Copy-Item -Destination $nsisPluginDir -Force
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
@@ -600,7 +592,7 @@ jobs:
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
@@ -619,7 +611,7 @@ jobs:
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
@@ -711,7 +703,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign bin and installer
repo: netbirdio/sign-pipelines


@@ -14,9 +14,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'
inputs: '{ "sha": "${{ github.sha }}" }'


@@ -3,7 +3,7 @@ name: sync tag
on:
push:
tags:
- "v*"
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
@@ -29,7 +29,7 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
@@ -42,10 +42,10 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
inputs: '{ "tag": "${{ github.ref_name }}" }'


@@ -6,10 +6,10 @@ on:
- main
pull_request:
paths:
- "infrastructure_files/**"
- ".github/workflows/test-infrastructure-files.yml"
- "management/cmd/**"
- "signal/cmd/**"
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
store: ["sqlite", "postgres", "mysql"]
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -68,17 +68,15 @@ jobs:
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -141,8 +139,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
@@ -256,9 +254,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh


@@ -3,9 +3,9 @@ name: update docs
on:
push:
tags:
- "v*"
- 'v*'
paths:
- "shared/management/http/api/openapi.yml"
- 'shared/management/http/api/openapi.yml'
jobs:
trigger_docs_api_update:
@@ -13,10 +13,10 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
uses: benc-uk/workflow-dispatch@v1
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}" }'


@@ -19,17 +19,15 @@ jobs:
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
install-mode: binary
@@ -44,11 +42,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
- name: Build Wasm client
@@ -69,3 +65,4 @@ jobs:
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi


@@ -15,7 +15,6 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
- [Contributing to NetBird](#contributing-to-netbird)
- [Contents](#contents)
- [Code of conduct](#code-of-conduct)
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
- [Directory structure](#directory-structure)
- [Development setup](#development-setup)
- [Requirements](#requirements)
@@ -34,14 +33,6 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
By participating, you are expected to uphold this code. Please report
unacceptable behavior to community@netbird.io.
## Discuss changes with the NetBird team first
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
## Directory structure
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.

README.md

@@ -1,134 +1,147 @@
<div align="center">
<p align="center">
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
</p>
<p align="center">
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
</a>
<br/>
<br/>
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
<p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
</a>
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
</a>
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
</a>
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
<p align="center">
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
</strong>
<strong>
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>
<br>
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open Source Network Security in a Single Platform
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-host NetBird (video)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features
| Connectivity | Management | Security | Automation | Platforms |
|---|---|---|---|---|
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
| Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
| [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
| | | | | ✓ OpenWRT |
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
| Connectivity | Management | Security | Automation| Platforms |
|----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
- Check the NetBird [admin UI](https://app.netbird.io/).
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
- Check NetBird [admin UI](https://app.netbird.io/).
- Add more machines.
### Quickstart with self-hosted NetBird
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
**Infrastructure requirements:**
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
- A **public domain** name pointing to the VM.
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions it is available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`.
- [curl](https://curl.se/) installed. It is usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`.
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`.
### A bit on NetBird internals
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) has been used successfully for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
### Support acknowledgement
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Acknowledgements
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
### Legal
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

@@ -11,7 +11,7 @@ import (
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Fatal(err)
}
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)

@@ -12,7 +12,6 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -85,12 +84,6 @@ type Options struct {
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// BlockLANAccess blocks the embedded peer from reaching the host's
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
// when the embedded client must never act as a stepping stone into
// the host's local network (e.g. the proxy's overlay peer).
BlockLANAccess bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
@@ -101,26 +94,6 @@ type Options struct {
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
// Performance configures the tunnel's buffer pool cap and batch size.
Performance Performance
}
// Performance configures the embedded client's tunnel memory/throughput knobs.
//
// These settings are process-global: any non-nil field also becomes the
// default for Clients constructed by later embed.New calls in the same
// process. Nil fields are ignored.
type Performance struct {
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
// leaves the pool unbounded. Lower values trade throughput for a
// tighter memory ceiling. May also be changed on a running Client via
// Client.SetPerformance, provided this field was nonzero at construction.
PreallocatedBuffersPerPool *uint32
// MaxBatchSize overrides the number of packets the tunnel reads or
// writes per syscall, which also bounds eager buffer allocation per
// worker. Zero uses the platform default. Applied at construction
// only; ignored by Client.SetPerformance.
MaxBatchSize *uint32
}
// validateCredentials checks that exactly one credential type is provided
@@ -202,7 +175,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound,
BlockLANAccess: &opts.BlockLANAccess,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
@@ -220,13 +192,6 @@ func New(opts Options) (*Client, error) {
config.PrivateKey = opts.PrivateKey
}
if opts.Performance.PreallocatedBuffersPerPool != nil {
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
}
if opts.Performance.MaxBatchSize != nil {
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
@@ -440,21 +405,6 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
}, nil
}
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false
}
state, found := c.recorder.PeerStateByIP(ip.String())
if !found {
return "", "", false
}
return state.PubKey, state.FQDN, true
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
@@ -523,25 +473,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
// takes effect, and only when it was nonzero at construction;
// MaxBatchSize is construction-only and returns an error if set here.
//
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
// running yet.
func (c *Client) SetPerformance(t Performance) error {
if t.MaxBatchSize != nil {
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
}
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetPerformance(internal.Performance{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
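
The optional-pointer pattern used by `Performance` and `SetPerformance` in this hunk can be sketched on its own. The example below is a simplified illustration, not the code from the diff: `tuning`, `settings`, `newSettings`, `retune` and `u32` are hypothetical names, and the rule that the pool cap can only be retuned when it was nonzero at construction is deliberately not modeled. It shows nil fields being ignored and a construction-only field being rejected at runtime.

```go
package main

import (
	"errors"
	"fmt"
)

// tuning is a hypothetical stand-in for the Performance struct.
// A nil field means "leave the current value alone".
type tuning struct {
	PoolCap      *uint32 // may be changed at runtime
	MaxBatchSize *uint32 // construction-only
}

// settings holds the effective values after construction.
type settings struct {
	poolCap      uint32
	maxBatchSize uint32
}

var errConstructionOnly = errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")

// newSettings applies every non-nil field of t; nil fields keep the zero default.
func newSettings(t tuning) *settings {
	s := &settings{}
	if t.PoolCap != nil {
		s.poolCap = *t.PoolCap
	}
	if t.MaxBatchSize != nil {
		s.maxBatchSize = *t.MaxBatchSize
	}
	return s
}

// retune changes only the runtime-adjustable field and rejects the
// construction-only one before touching any state.
func (s *settings) retune(t tuning) error {
	if t.MaxBatchSize != nil {
		return errConstructionOnly
	}
	if t.PoolCap != nil {
		s.poolCap = *t.PoolCap
	}
	return nil
}

// u32 returns a pointer to v, which keeps call sites short.
func u32(v uint32) *uint32 { return &v }

func main() {
	s := newSettings(tuning{PoolCap: u32(64), MaxBatchSize: u32(16)})
	fmt.Println(s.retune(tuning{PoolCap: u32(32)}), s.poolCap)
	fmt.Println(s.retune(tuning{MaxBatchSize: u32(8)}), s.maxBatchSize)
}
```

Using pointers rather than plain integers lets a caller distinguish "not set" from an explicit zero, which matters here because the comments in the hunk give zero its own meaning.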

@@ -52,10 +52,9 @@ func (m *externalChainMonitor) start() {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
done := make(chan struct{})
m.done = done
m.done = make(chan struct{})
go m.run(ctx, done)
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
@@ -73,8 +72,8 @@ func (m *externalChainMonitor) stop() {
<-done
}
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
defer close(done)
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

@@ -360,13 +360,7 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
return true
}
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
// alias by default, ensure explicit ip for localhost.
host := parsedURL.Hostname()
if host == "" {
host = "127.0.0.1"
}
addr := net.JoinHostPort(host, port)
addr := fmt.Sprintf(":%s", port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
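
The two sides of this hunk differ in the address they dial: one uses `:<port>`, the other an explicit host. The sketch below isolates the explicit-host variant so it can be tried on its own. It is an illustration under assumptions, not the function from the diff: `dialAddr` and `isLocalPortUsed` are hypothetical helpers, and the one-second timeout is arbitrary.

```go
package main

import (
	"fmt"
	"net"
	"net/url"
	"time"
)

// dialAddr builds the address to probe for a redirect URL.
// An empty host falls back to 127.0.0.1 instead of relying on the
// wildcard address being treated as an alias for localhost.
func dialAddr(redirectURL string) (string, error) {
	u, err := url.Parse(redirectURL)
	if err != nil {
		return "", err
	}
	port := u.Port()
	if port == "" {
		return "", fmt.Errorf("no port in %q", redirectURL)
	}
	host := u.Hostname()
	if host == "" {
		host = "127.0.0.1"
	}
	// JoinHostPort adds the brackets that IPv6 literals need.
	return net.JoinHostPort(host, port), nil
}

// isLocalPortUsed reports whether a TCP connection to the redirect URL's
// host and port succeeds within the timeout.
func isLocalPortUsed(redirectURL string) bool {
	addr, err := dialAddr(redirectURL)
	if err != nil {
		return false
	}
	conn, err := net.DialTimeout("tcp", addr, time.Second)
	if err != nil {
		return false
	}
	_ = conn.Close()
	return true
}

func main() {
	for _, raw := range []string{"http://localhost:53000/", "http://:53000/", "http://[::1]:53000/"} {
		addr, err := dialAddr(raw)
		fmt.Println(addr, err)
	}
}
```

Splitting address construction from the dial keeps the fallback rule testable without opening any sockets.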

@@ -339,7 +339,8 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
return strings.HasSuffix(qname, "."+entry.Pattern)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
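
The label-boundary behaviour that the wildcard branch is meant to have can be checked with a small standalone sketch. It assumes that the stored pattern is the wildcard with its leading `*.` already removed (for example `b.test.` for `*.b.test.`); `wildcardMatch` and `fqdn` are hypothetical helpers written for this illustration, not functions from the handler chain.

```go
package main

import (
	"fmt"
	"strings"
)

// fqdn appends the trailing dot if it is missing.
func fqdn(name string) string {
	if strings.HasSuffix(name, ".") {
		return name
	}
	return name + "."
}

// wildcardMatch reports whether qname lies strictly below zone.
// zone is assumed to be the wildcard pattern without its leading "*.".
func wildcardMatch(qname, zone string) bool {
	qname = fqdn(strings.ToLower(qname))
	zone = fqdn(strings.ToLower(zone))
	// Prefixing the zone with "." pins the match to a label boundary,
	// so "x.ab.test." does not match "b.test.", and the apex itself
	// ("b.test.") is rejected because it has no extra label.
	return strings.HasSuffix(qname, "."+zone)
}

func main() {
	for _, q := range []string{"x.b.test.", "x.y.b.test.", "x.ab.test.", "b.test."} {
		fmt.Println(q, wildcardMatch(q, "b.test."))
	}
}
```

A plain suffix test without the leading dot would accept `x.ab.test.` for the zone `b.test.`, which is the overlap the test cases in the next file guard against.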

@@ -164,54 +164,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard label-boundary mismatch (suffix overlap)",
handlerDomain: "*.b.test.",
queryDomain: "x.ab.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard label-boundary match",
handlerDomain: "*.b.test.",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard multi-label match",
handlerDomain: "*.b.test.",
queryDomain: "x.y.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on multi-label apex",
handlerDomain: "*.b.test.",
queryDomain: "b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on unrelated suffix containment",
handlerDomain: "*.example.com.",
queryDomain: "notexample.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard accepts pattern registered without trailing dot",
handlerDomain: "*.b.test",
queryDomain: "x.b.test.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
}
for _, tt := range tests {
@@ -321,19 +273,6 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "overlapping wildcard suffixes route to correct handler",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "app.ab.test.",
expectedCalls: 1,
expectedHandler: 1,
},
{
name: "root zone with specific domain",
handlers: []struct {

@@ -26,19 +26,6 @@ type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
// client knows about and whether that peer is currently connected. The
// local resolver uses this to suppress A/AAAA answers whose RDATA points
// at a disconnected peer (typical case: a synthesized private-service
// record pointing at an embedded proxy peer that just went offline).
//
// known=false means the IP isn't in the local peerstore at all — the
// record is left alone (it points at something outside our mesh, e.g.
// a non-peer upstream).
type PeerConnectivity interface {
IsConnectedByIP(ip string) (known, connected bool)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
@@ -46,11 +33,6 @@ type Resolver struct {
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
// peerConn, when non-nil, is consulted on every A/AAAA answer to
// drop records pointing at disconnected peers. nil disables the
// filter and preserves the legacy "return whatever is registered"
// behaviour for callers that never wire a status source.
peerConn PeerConnectivity
ctx context.Context
cancel context.CancelFunc
@@ -67,15 +49,6 @@ func NewResolver() *Resolver {
}
}
// SetPeerConnectivity wires the per-IP connectivity check used to filter
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
// Safe to call multiple times; the latest value wins.
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerConn = p
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
@@ -122,7 +95,6 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
@@ -464,78 +436,6 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
// a known but disconnected peer. The synthesized private-service zones
// emit one A record per connected proxy peer in a cluster; when a peer
// goes offline, the server-side refresh removes the record from the
// next netmap, but the client may still hold the previous netmap for a
// short window. This filter is the local belt to that braces — even on
// the stale netmap, the resolver hides the offline target.
//
// Records pointing at unknown IPs (outside the local peerstore, e.g.
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
// through untouched.
//
// Escape hatch: if filtering would leave the answer empty AND at least
// one record was filtered, the original list is returned. Better to
// hand the client a record that may not respond than NXDOMAIN it
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
return records
}
d.mu.RLock()
checker := d.peerConn
d.mu.RUnlock()
if checker == nil {
return records
}
kept := make([]dns.RR, 0, len(records))
var dropped int
for _, rr := range records {
ip := extractRecordIP(rr)
if ip == "" {
kept = append(kept, rr)
continue
}
known, connected := checker.IsConnectedByIP(ip)
if known && !connected {
dropped++
continue
}
kept = append(kept, rr)
}
if dropped == 0 {
return records
}
if len(kept) == 0 {
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
return records
}
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
return kept
}
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
// an A or AAAA record, or "" for any other record type.
func extractRecordIP(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
if r.A == nil {
return ""
}
return r.A.String()
case *dns.AAAA:
if r.AAAA == nil {
return ""
}
return r.AAAA.String()
}
return ""
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
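
The filtering rule described in the `filterDisconnectedPeerAnswers` comment, including its escape hatch, can be sketched without the DNS types. The example below is a simplified illustration that works on plain IP strings instead of `dns.RR` values; `connectivity` mirrors the shape of the `PeerConnectivity` interface, while `staticPeers` and `filterDisconnected` are hypothetical names introduced here.

```go
package main

import "fmt"

// connectivity mirrors the shape of the PeerConnectivity interface.
type connectivity interface {
	IsConnectedByIP(ip string) (known, connected bool)
}

// staticPeers is a hypothetical in-memory peer table: IP -> connected.
// IPs missing from the map are reported as unknown.
type staticPeers map[string]bool

func (s staticPeers) IsConnectedByIP(ip string) (known, connected bool) {
	c, ok := s[ip]
	return ok, c
}

// filterDisconnected drops IPs that belong to known but disconnected peers.
// Unknown IPs are kept. If filtering would leave nothing, the original
// list is returned unchanged.
func filterDisconnected(check connectivity, ips []string) []string {
	if check == nil || len(ips) == 0 {
		return ips
	}
	kept := make([]string, 0, len(ips))
	for _, ip := range ips {
		known, connected := check.IsConnectedByIP(ip)
		if known && !connected {
			continue
		}
		kept = append(kept, ip)
	}
	if len(kept) == 0 {
		// Escape hatch: every answer pointed at an offline peer.
		return ips
	}
	return kept
}

func main() {
	peers := staticPeers{"100.64.0.10": true, "100.64.0.11": false}
	fmt.Println(filterDisconnected(peers, []string{"100.64.0.10", "100.64.0.11"}))
	fmt.Println(filterDisconnected(peers, []string{"203.0.113.5", "100.64.0.11"}))
	fmt.Println(filterDisconnected(peers, []string{"100.64.0.11"}))
}
```

Returning the original list in the all-offline case trades a possibly dead answer for avoiding an empty response, which is the trade-off the comment in the hunk argues for.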

@@ -30,21 +30,6 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return nil, nil
}
// mockPeerConnectivity returns canned (known, connected) results per IP.
// Used by the disconnected-peer filter tests below. IPs not in the map
// are reported as unknown so the filter leaves them alone.
type mockPeerConnectivity struct {
byIP map[string]struct{ known, connected bool }
}
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
v, ok := m.byIP[ip]
if !ok {
return false, false
}
return v.known, v.connected
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -2667,114 +2652,3 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver.isInManagedZone(qname)
}
}
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
// connectivity-aware filtering layered on top of lookupRecords:
// when an A record's IP belongs to a known peer that's disconnected,
// the record is dropped from the answer. Records for unknown IPs pass
// through. If filtering would empty the answer entirely and at least
// one record was dropped, the original list is restored (escape hatch
// for the "all proxies offline" case).
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
zone := "svc.cluster.netbird."
connectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.10",
}
disconnectedRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "100.64.0.11",
}
unknownRec := nbdns.SimpleRecord{
Name: zone,
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 5,
RData: "203.0.113.5",
}
type ipState struct{ known, connected bool }
tests := []struct {
name string
records []nbdns.SimpleRecord
connByIP map[string]ipState
wantInOrder []string
}{
{
name: "drops disconnected peer, keeps connected",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: true},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.10"},
},
{
name: "unknown IPs pass through untouched",
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"203.0.113.5"},
},
{
name: "all disconnected falls back to original list",
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
connByIP: map[string]ipState{
"100.64.0.10": {known: true, connected: false},
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
},
{
name: "no checker wired returns all records",
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
if tc.connByIP != nil {
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
for ip, st := range tc.connByIP {
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
}
resolver.SetPeerConnectivity(cm)
}
resolver.Update([]nbdns.CustomZone{{
Domain: strings.TrimSuffix(zone, "."),
Records: tc.records,
NonAuthoritative: true,
}})
var got *dns.Msg
writer := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
got = m
return nil
},
}
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
resolver.ServeDNS(writer, req)
require.NotNil(t, got, "resolver must produce a response")
require.Len(t, got.Answer, len(tc.wantInOrder),
"answer count must match expected: %v", tc.wantInOrder)
for i, want := range tc.wantInOrder {
a, ok := got.Answer[i].(*dns.A)
require.True(t, ok, "answer[%d] must be an A record", i)
assert.Equal(t, want, a.A.String(),
"answer[%d] expected %s got %s", i, want, a.A.String())
}
})
}
}

@@ -301,11 +301,6 @@ func newDefaultServer(
warningDelayBase: defaultWarningDelayBase,
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
// suppress A/AAAA answers that point at disconnected peers (typical
// case: synthesised private-service records pointing at an embedded
// proxy peer that just went offline).
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -1391,25 +1386,3 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
}
return nil
}
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
// the local resolver can ask "is this IP a known peer and is it
// connected?" without taking on the peer package as a dependency.
// A nil status recorder always reports known=false so the resolver
// short-circuits to the legacy "return everything" path.
type localPeerConnectivity struct {
status *peer.Status
}
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
if l.status == nil {
return false, false
}
state, ok := l.status.PeerStateByIP(ip)
if !ok {
return false, false
}
return true, state.ConnStatus == peer.StatusConnected
}

@@ -1967,29 +1967,6 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// Performance bundles runtime-adjustable tunnel pool knobs.
// See Engine.SetPerformance. Nil fields are ignored.
type Performance struct {
PreallocatedBuffersPerPool *uint32
}
// SetPerformance applies the given tuning to this engine's live Device.
func (e *Engine) SetPerformance(t Performance) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -66,8 +66,8 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/shared/netiputil"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)

@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, flags, err := parseRouteMessage(buf[:n])
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
@@ -66,10 +66,6 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
switch msg.Type {
case unix.RTM_ADD:
if systemops.IgnoreAddedDefaultRoute(flags) {
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
continue
}
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
@@ -82,26 +78,22 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, 0, fmt.Errorf("parse RIB: %v", err)
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
r, err := systemops.MsgToRoute(msg)
if err != nil {
return nil, 0, err
}
return r, msg.Flags, nil
return systemops.MsgToRoute(msg)
}
// waitReadable blocks until fd has data to read, or ctx is cancelled.

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -900,7 +899,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -909,6 +908,26 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -185,12 +185,9 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
return s.eventsChan
}
// Status holds a state of peers, signal, management connections and relays.
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
// every private-service request) don't contend against each other.
// Pure read methods take RLock; anything that mutates state takes Lock.
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.RWMutex
mux sync.Mutex
peers map[string]State
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
@@ -286,8 +283,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (State, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
state, ok := d.peers[peerPubKey]
if !ok {
@@ -297,8 +294,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
for _, state := range d.peers {
if state.IP == ip {
@@ -308,25 +305,6 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
return "", false
}
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -724,8 +702,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer.Clone()
}
@@ -931,8 +909,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -940,14 +918,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetLazyConnection() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return d.lazyConnectionEnabled
}
func (d *Status) GetManagementState() ManagementState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -973,8 +951,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// if peer is connected to the management then login is not expired
if d.managementState {
@@ -989,8 +967,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -1000,8 +978,8 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -1030,8 +1008,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.ingressGwMgr == nil {
return nil
}
@@ -1040,16 +1018,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
@@ -1065,8 +1043,8 @@ func (d *Status) GetFullStatus() FullStatus {
LazyConnectionEnabled: d.GetLazyConnection(),
}
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
@@ -1241,8 +1219,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
}
func (d *Status) PeersStatus() (*configurer.Stats, error) {
d.mux.RLock()
defer d.mux.RUnlock()
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
}
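
The Status comment in this file's diff describes a locking split: pure read methods take RLock so that hot lookups such as PeerStateByIP do not contend with each other, while anything that mutates state takes Lock. Below is a minimal sketch of that split, assuming a stripped-down registry. The names `registry`, `peerState`, `add` and `stateByIP` are hypothetical and not part of the NetBird code.

```go
package main

import "sync"

// peerState is a hypothetical, reduced stand-in for the State struct.
type peerState struct {
	PubKey string
	IP     string
	IPv6   string
}

// registry guards its map with an RWMutex: readers share RLock,
// writers take the exclusive Lock.
type registry struct {
	mux   sync.RWMutex
	peers map[string]peerState
}

// add mutates the map, so it needs the exclusive lock.
func (r *registry) add(s peerState) {
	r.mux.Lock()
	defer r.mux.Unlock()
	if r.peers == nil {
		r.peers = make(map[string]peerState)
	}
	r.peers[s.PubKey] = s
}

// stateByIP only reads, so concurrent callers can hold RLock together.
// It matches either address family and rejects an empty input.
func (r *registry) stateByIP(ip string) (peerState, bool) {
	if ip == "" {
		return peerState{}, false
	}
	r.mux.RLock()
	defer r.mux.RUnlock()
	for _, s := range r.peers {
		if (s.IP != "" && s.IP == ip) || (s.IPv6 != "" && s.IPv6 == ip) {
			return s, true
		}
	}
	return peerState{}, false
}
```

With a plain sync.Mutex the same lookup is still correct, but every concurrent reader serializes behind the others.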

View File

@@ -63,33 +63,6 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestStatus_PeerStateByIP(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
state, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "known tunnel IP should resolve to a peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
_, ok = status.PeerStateByIP("100.64.0.99")
req.False(ok, "unknown IP must report ok=false")
}
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
state, ok := status.PeerStateByIP("fd00::1")
req.True(ok, "IPv6-only match must resolve to the peer state")
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -179,10 +179,8 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv4zero
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
@@ -205,7 +203,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv6zero
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if runtime.GOOS == "linux" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}

View File

@@ -28,15 +28,6 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct {
ifaceName string
spk []byte
@@ -45,7 +36,7 @@ type Manager struct {
preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler
server rpServer
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
@@ -60,22 +51,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
}
func (m *Manager) GetPubKey() []byte {
@@ -89,16 +65,6 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil {
@@ -113,16 +79,6 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
}
peerID, err := m.server.AddPeer(pcfg)
if err != nil {
@@ -226,31 +182,24 @@ func (m *Manager) Run() error {
return err
}
server, err := rp.NewUDPServer(conf)
m.server, err = rp.NewUDPServer(conf)
if err != nil {
return err
}
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port)
return server.Run()
return m.server.Run()
}
// Close closes the Rosenpass server
func (m *Manager) Close() error {
m.lock.Lock()
server := m.server
m.server = nil
m.lock.Unlock()
if server == nil {
return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
if m.server != nil {
err := m.server.Close()
if err != nil {
log.Errorf("failed closing local rosenpass server")
}
m.server = nil
}
return nil
}
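
The comment in the addPeer hunk above describes normalizing an IPv4 endpoint to its IPv4-mapped IPv6 form so that its address family matches a socket bound on the IPv6 wildcard. Below is a minimal, stdlib-only sketch of that one step. The helper name `toV4Mapped` is hypothetical and does not appear in the NetBird code.

```go
package main

import "net"

// toV4Mapped returns IPv4 addresses in their 16-byte IPv4-mapped IPv6
// form (::ffff:a.b.c.d) and leaves every other address unchanged.
func toV4Mapped(ip net.IP) net.IP {
	if v4 := ip.To4(); v4 != nil {
		// To16 on a 4-byte address yields the 16-byte mapped form.
		return v4.To16()
	}
	return ip
}
```

The result is still recognized as IPv4 by `To4()`, so only the byte length handed to the socket layer changes.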

View File

@@ -1,412 +1,14 @@
package rosenpass
import (
"errors"
"os"
"sync"
"testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -1,42 +0,0 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}
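
The doc comment on DeterministicSeedKey states that both peers derive the same output because the inputs are ordered lexicographically before concatenation. Below is a stdlib-only sketch of that ordering rule. It returns a raw `[32]byte` instead of `*wgtypes.Key` to avoid the external dependency, and the name `seedKey` is hypothetical.

```go
package main

import "fmt"

// seedKey mirrors the ordering logic only: the lexicographically larger
// key contributes the first 16 bytes, the smaller one the last 16.
func seedKey(localKey, remoteKey string) ([32]byte, error) {
	var out [32]byte
	if len(localKey) < 16 || len(remoteKey) < 16 {
		return out, fmt.Errorf("peer keys must be at least 16 bytes (got local=%d, remote=%d)",
			len(localKey), len(remoteKey))
	}
	// Pick the order from the key values, not from which side is calling.
	hi, lo := localKey, remoteKey
	if remoteKey > localKey {
		hi, lo = remoteKey, localKey
	}
	copy(out[:16], hi[:16])
	copy(out[16:], lo[:16])
	return out, nil
}
```

Because the order depends only on the two key values, swapping the arguments yields the same 32 bytes. As the original comment notes, a value derived purely from public keys is not quantum-safe and only serves as a seed.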

View File

@@ -1,44 +0,0 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -1,9 +0,0 @@
//go:build dragonfly || freebsd || netbsd || openbsd
package systemops
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor.
func IgnoreAddedDefaultRoute(flags int) bool {
return filterRoutesByFlags(flags)
}

View File

@@ -1,21 +0,0 @@
//go:build darwin
package systemops
import "golang.org/x/sys/unix"
// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the
// given flags should be ignored by the network monitor. Scoped routes
// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the
// unscoped default the kernel uses for general egress, so flapping ones (e.g.
// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults)
// must not trigger an engine restart.
func IgnoreAddedDefaultRoute(flags int) bool {
if filterRoutesByFlags(flags) {
return true
}
if flags&unix.RTF_IFSCOPE != 0 {
return true
}
return false
}

View File

@@ -188,9 +188,7 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
}
doneChan := make(chan struct{})
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
// enough for teardown to finish while staying under that deadline.
timeout := time.NewTimer(20 * time.Second)
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
go func() {

View File

@@ -96,19 +96,17 @@ func (m *Manager) Stop(ctx context.Context) error {
}
m.mu.Lock()
cancel := m.cancel
done := m.done
m.mu.Unlock()
defer m.mu.Unlock()
if cancel == nil {
if m.cancel == nil {
return nil
}
cancel()
m.cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
case <-m.done:
}
return nil
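
One side of this hunk copies `cancel` and `done` while holding the mutex and releases it before blocking; the other keeps the mutex held (via `defer`) for the whole wait. Below is a minimal sketch of the snapshot-then-release variant. The `worker` type, its `Start` method and the `stopWithTimeout` helper are hypothetical and only loosely modeled on the Manager in this diff.

```go
package main

import (
	"context"
	"sync"
	"time"
)

// worker is a hypothetical stand-in for a manager with a background goroutine.
type worker struct {
	mu     sync.Mutex
	cancel context.CancelFunc
	done   chan struct{}
}

// Start launches a goroutine that exits when its context is cancelled.
func (w *worker) Start() {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	w.mu.Lock()
	w.cancel, w.done = cancel, done
	w.mu.Unlock()
	go func() {
		defer close(done)
		<-ctx.Done()
	}()
}

// Stop snapshots the fields under the lock, then waits without holding it,
// so other methods that need w.mu are not blocked during shutdown.
func (w *worker) Stop(ctx context.Context) error {
	w.mu.Lock()
	cancel, done := w.cancel, w.done
	w.mu.Unlock()

	if cancel == nil {
		return nil // never started
	}
	cancel()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		return nil
	}
}

// stopWithTimeout bounds the wait for shutdown with a deadline.
func stopWithTimeout(w *worker, d time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()
	return w.Stop(ctx)
}
```

Holding the mutex across the `select` is simpler to read, but any goroutine that needs the same mutex in order to finish would then prevent `done` from ever closing.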

View File

@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)


@@ -3,14 +3,15 @@
package system
import (
"bytes"
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
"golang.org/x/sys/unix"
log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo"
@@ -28,11 +29,19 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
kernelName, kernelVersion, kernelPlatform := kernelInfo()
info := _getInfo()
for strings.Contains(info, "broken pipe") {
info = _getInfo()
time.Sleep(500 * time.Millisecond)
}
osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
osName, osVersion := readOsReleaseFile()
if osName == "" {
osName = kernelName
osName = osInfo[3]
}
systemHostname, _ := os.Hostname()
@@ -49,8 +58,8 @@ func GetInfo(ctx context.Context) *Info {
}
gio := &Info{
Kernel: kernelName,
Platform: kernelPlatform,
Kernel: osInfo[0],
Platform: osInfo[2],
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
@@ -58,7 +67,7 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: kernelVersion,
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
@@ -69,12 +78,18 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func kernelInfo() (string, string, string) {
var uts unix.Utsname
if err := unix.Uname(&uts); err != nil {
return "", "", ""
func _getInfo() string {
cmd := exec.Command("uname", "-srio")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("getInfo: %s", err)
}
return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:])
return out.String()
}
func sysInfo() (string, string, string) {
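The hunk above contains a variant that splits `uname -srio` output on single spaces and then indexes fixed positions (`osInfo[0]` through `osInfo[3]`). If the command fails or prints fewer fields, that indexing can panic. A minimal sketch of a more defensive parse, assuming the same four-field layout; `unameFields` and `parseUname` are hypothetical names, not part of the repository:

```go
package main

import (
	"fmt"
	"strings"
)

// unameFields is a hypothetical container for the four values the hunk
// reads from `uname -srio` output.
type unameFields struct {
	Kernel, Release, Platform, OS string
}

// parseUname splits on any run of whitespace and checks the field count
// before indexing, so short or empty output yields ok == false instead
// of an out-of-range panic.
func parseUname(out string) (unameFields, bool) {
	f := strings.Fields(out)
	if len(f) < 4 {
		return unameFields{}, false
	}
	return unameFields{Kernel: f[0], Release: f[1], Platform: f[2], OS: f[3]}, true
}

func main() {
	info, ok := parseUname("Linux 6.1.0 x86_64 GNU/Linux\n")
	fmt.Println(info.Kernel, info.Release, info.Platform, info.OS, ok)

	_, ok = parseUname("")
	fmt.Println(ok)
}
```

`strings.Fields` also discards the trailing newline, so the separate `ReplaceAll` calls would not be needed in this sketch.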


@@ -6,7 +6,6 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"sync"
"syscall/js"
"time"
@@ -14,7 +13,7 @@ import (
)
const (
certValidationTimeout = 5 * time.Minute
certValidationTimeout = 60 * time.Second
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
@@ -47,31 +46,17 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool, 1)
errorChan := make(chan error, 1)
resultChan := make(chan bool)
errorChan := make(chan error)
// Release from inside the callbacks so a post-timeout promise resolution
// does not invoke an already-released func.
var thenFn, catchFn js.Func
var releaseOnce sync.Once
release := func() {
releaseOnce.Do(func() {
thenFn.Release()
catchFn.Release()
})
}
thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
resultChan <- args[0].Bool()
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
return nil
})
catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
defer release()
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
errorChan <- fmt.Errorf("certificate validation failed")
return nil
})
promise.Call("then", thenFn).Call("catch", catchFn)
}))
select {
case result := <-resultChan:


@@ -11,7 +11,6 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"syscall/js"
"time"
@@ -58,8 +57,6 @@ type RDCleanPathProxy struct {
}
activeConnections map[string]*proxyConnection
destinations map[string]string
pendingHandlers map[string]js.Func
nextID atomic.Uint64
mu sync.Mutex
}
@@ -69,15 +66,8 @@ type proxyConnection struct {
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
// global handle map and MUST be released, otherwise every connection
// leaks the Go memory the closure captures.
wsHandlerFn js.Func
onMessageFn js.Func
onCloseFn js.Func
cleanupOnce sync.Once
ctx context.Context
cancel context.CancelFunc
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
@@ -90,11 +80,7 @@ func NewRDCleanPathProxy(client interface {
}
}
// CreateProxy creates a new proxy endpoint for the given destination.
// The registered handler fn and its destinations/pendingHandlers entries are
// only released once a connection is established and cleanupConnection runs.
// If a caller invokes CreateProxy but never connects to the returned URL,
// those entries stay pinned for the lifetime of the page.
// CreateProxy creates a new proxy endpoint for the given destination
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := net.JoinHostPort(hostname, port)
@@ -102,7 +88,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1))
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
p.mu.Lock()
if p.destinations == nil {
@@ -114,7 +100,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
@@ -122,14 +108,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
})
p.mu.Lock()
if p.pendingHandlers == nil {
p.pendingHandlers = make(map[string]js.Func)
}
p.pendingHandlers[proxyID] = handlerFn
p.mu.Unlock()
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn)
}))
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
@@ -163,10 +142,6 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
p.mu.Lock()
p.activeConnections[proxyID] = conn
if fn, ok := p.pendingHandlers[proxyID]; ok {
conn.wsHandlerFn = fn
delete(p.pendingHandlers, proxyID)
}
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
@@ -175,7 +150,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
@@ -183,15 +158,13 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
})
ws.Set("onGoMessage", conn.onMessageFn)
}))
conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
})
ws.Set("onGoClose", conn.onCloseFn)
}))
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
@@ -288,49 +261,25 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
conn.cleanupOnce.Do(func() {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id))
// Detach before releasing so late JS calls surface as TypeError instead
// of silent "call to released function".
if conn.wsHandlers.Truthy() {
conn.wsHandlers.Set("onGoMessage", js.Undefined())
conn.wsHandlers.Set("onGoClose", js.Undefined())
}
// wsHandlerFn may be zero-value if the pending handler lookup missed.
if conn.wsHandlerFn.Truthy() {
conn.wsHandlerFn.Release()
}
if conn.onMessageFn.Truthy() {
conn.onMessageFn.Release()
}
if conn.onCloseFn.Truthy() {
conn.onCloseFn.Release()
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
delete(p.destinations, conn.id)
delete(p.pendingHandlers, conn.id)
p.mu.Unlock()
})
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
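The `CreateProxy` hunk above derives `proxyID` in two different ways: from `len(p.activeConnections)` and from `p.nextID.Add(1)`. A length-based ID can repeat once an entry is removed from the map, whereas a monotonically increasing counter cannot. A minimal sketch of the difference; `registry` and its methods are hypothetical simplifications, not the proxy's actual types:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// registry is a hypothetical, simplified stand-in for a map of live
// connections plus an ID counter.
type registry struct {
	active map[string]bool
	nextID atomic.Uint64
}

// idFromLen derives the ID from the current map size.
func (r *registry) idFromLen() string {
	return fmt.Sprintf("proxy_%d", len(r.active))
}

// idFromCounter derives the ID from a counter that only ever grows.
func (r *registry) idFromCounter() string {
	return fmt.Sprintf("proxy_%d", r.nextID.Add(1))
}

func main() {
	r := &registry{active: map[string]bool{}}

	a := r.idFromLen()
	r.active[a] = true
	b := r.idFromLen()
	r.active[b] = true

	delete(r.active, a) // the first connection closes
	c := r.idFromLen()  // map size is 1 again, same as when b was issued

	fmt.Println("length-based IDs collide:", b == c)

	x, y := r.idFromCounter(), r.idFromCounter()
	fmt.Println("counter-based IDs differ:", x != y)
}
```

The sketch registers each ID immediately. If registration happens later than ID generation, as it appears to in the hunk, two length-based IDs could collide even without any deletion.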


@@ -13,7 +13,7 @@ import (
func CreateJSInterface(client *Client) js.Value {
jsInterface := js.Global().Get("Object").Call("create", js.Null())
writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf(false)
}
@@ -32,10 +32,9 @@ func CreateJSInterface(client *Client) js.Value {
_, err := client.Write(bytes)
return js.ValueOf(err == nil)
})
jsInterface.Set("write", writeFunc)
}))
resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf(false)
}
@@ -43,26 +42,14 @@ func CreateJSInterface(client *Client) js.Value {
rows := args[1].Int()
err := client.Resize(cols, rows)
return js.ValueOf(err == nil)
})
jsInterface.Set("resize", resizeFunc)
}))
closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any {
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
client.Close()
return js.Undefined()
})
jsInterface.Set("close", closeFunc)
}))
go func() {
readLoop(client, jsInterface)
// Detach before releasing so late JS calls surface as TypeError instead
// of silent "call to released function".
jsInterface.Set("write", js.Undefined())
jsInterface.Set("resize", js.Undefined())
jsInterface.Set("close", js.Undefined())
writeFunc.Release()
resizeFunc.Release()
closeFunc.Release()
}()
go readLoop(client, jsInterface)
return jsInterface
}


@@ -67,10 +67,6 @@ func init() {
rootCmd.AddCommand(newTokenCommands())
}
func RootCmd() *cobra.Command {
return rootCmd
}
func Execute() error {
return rootCmd.Execute()
}
@@ -172,7 +168,7 @@ func initializeConfig() error {
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv mgmtServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
@@ -328,24 +324,19 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
return
}
if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok {
if baseServer, ok := s.(*mgmtServer.BaseServer); ok {
baseServer.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
@@ -355,32 +346,38 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Go(func() {
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
})
}()
wg.Go(func() {
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
})
}()
if stunServer != nil {
wg.Go(func() {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
})
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error {
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
@@ -494,7 +491,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
// Extract port from listen address
@@ -505,7 +502,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (m
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := newServer(
mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{
NbConfig: mgmtConfig,
DNSDomain: "",
@@ -524,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (m
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn listener.Conn)
@@ -559,10 +556,6 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, id
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Embedded IdP (Dex)
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(w, r)
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
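The `createCombinedHandler` hunk above routes by path prefix inside a `switch`, with one case that forwards `/oauth2` requests to an optional IdP handler and a default that falls through to the management API. A minimal sketch of that dispatch shape using only `net/http`; `combined`, `respond`, and `serve` are hypothetical helpers, and the relay and gRPC branches are omitted:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// combined routes paths starting with /oauth2 to idp when one is
// configured and everything else to api. Both are placeholder handlers.
func combined(api, idp http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case idp != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
			idp.ServeHTTP(w, r)
		default:
			api.ServeHTTP(w, r)
		}
	})
}

// respond returns a handler that writes a fixed body.
func respond(body string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, body)
	})
}

// serve returns the body produced for a GET request on path.
func serve(h http.Handler, path string) string {
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
	return rec.Body.String()
}

func main() {
	h := combined(respond("api"), respond("idp"))
	fmt.Println(serve(h, "/oauth2/token"), serve(h, "/api/peers"))

	// Without an IdP handler, /oauth2 falls through to the API handler.
	fmt.Println(serve(combined(respond("api"), nil), "/oauth2/token"))
}
```

The nil check keeps the handler usable when no IdP is configured, at the cost of `/oauth2` requests reaching the default handler in that case.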


@@ -1,13 +0,0 @@
package cmd
import (
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
)
var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server {
return mgmtServer.NewServer(cfg)
}
func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) {
newServer = fn
}

go.mod

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5
require (
cunicu.li/go-rosenpass v0.5.42
cunicu.li/go-rosenpass v0.4.0
github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.19.0
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.4.0
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
@@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

go.sum

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,8 +390,6 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -501,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -902,8 +900,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=


@@ -112,7 +112,7 @@ func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
@@ -175,10 +175,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
continue
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
@@ -246,14 +242,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
_ = c.sendUpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
_ = c.sendUpdateAccountPeers(ctx, accountID)
})
return
}
@@ -269,7 +265,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
return c.sendUpdateAccountPeers(ctx, accountID, reason)
return c.sendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
@@ -363,14 +359,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
_ = c.sendUpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
_ = c.sendUpdateAccountPeers(ctx, accountID)
})
return
}


@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
found = true
select {
case channel <- update:
log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))


@@ -2,6 +2,7 @@ package manager
import (
"context"
"math/rand"
"sync"
"time"
@@ -11,44 +12,76 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
// cleanupWindow is the small grace period added on top of the
// staleness horizon before a sweep fires. It absorbs minor clock
// skew between the management server and the database and avoids
// firing a sweep right at the boundary where last_seen could still
// be one tick under the threshold.
cleanupWindow = 1 * time.Minute
// initialLoadMinDelay and initialLoadMaxDelay bracket the random
// delay applied before the post-restart catch-up query runs. Spread
// across replicas this prevents a thundering herd of catch-up
// queries hitting the database simultaneously after a deploy.
initialLoadMinDelay = 8 * time.Minute
initialLoadMaxDelay = 10 * time.Minute
)
var (
timeNow = time.Now
)
type ephemeralPeer struct {
id string
accountID string
deadline time.Time
next *ephemeralPeer
// accountEntry is the per-account state held by the cleanup tracker.
// We don't track which peers are pending — the sweep query gets the
// authoritative list straight from the database every time. We only
// need to know the latest disconnect we've observed for this account
// (so we can decide when it's safe to drop the entry) and the timer
// that will fire the next sweep.
type accountEntry struct {
lastDisconnectedAt time.Time
timer *time.Timer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
// EphemeralManager tracks accounts that may have ephemeral peers in
// need of cleanup and runs a per-account sweep at the appropriate
// time. State is in-memory and account-scoped: a sweep deletes any
// ephemeral peer in the account that has been disconnected for at
// least lifeTime, then either drops the account from the tracker
// (when no recent disconnects have arrived) or re-arms the timer.
type EphemeralManager struct {
store store.Store
peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
peersLock sync.Mutex
timer *time.Timer
accountsLock sync.Mutex
accounts map[string]*accountEntry
// initialLoadTimer is the one-shot timer used to defer the
// post-restart catch-up query; held so Stop() can cancel it.
initialLoadTimer *time.Timer
// stopped is flipped by Stop() so any timer that fires after
// teardown becomes a no-op instead of touching a half-dismantled
// store.
stopped bool
lifeTime time.Duration
cleanupWindow time.Duration
// initialLoadDelay returns the wall-clock delay to wait before
// running the post-restart catch-up query. Pluggable so tests can
// fire the load immediately.
initialLoadDelay func() time.Duration
// bgCtx is the long-lived context captured at LoadInitialPeers
// time. Timer-driven sweeps use it because they fire long after
// the original gRPC handler ctx that produced an OnPeerDisconnected
// call has been cancelled.
bgCtx context.Context
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
@@ -58,228 +91,265 @@ type EphemeralManager struct {
// NewEphemeralManager instantiates a new EphemeralManager.
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
peersManager: peersManager,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
store: store,
peersManager: peersManager,
accounts: make(map[string]*accountEntry),
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
initialLoadDelay: defaultInitialLoadDelay,
}
}
// SetMetrics attaches a metrics collector. Safe to call once before
// LoadInitialPeers; later attachment is fine but earlier loads won't be
// reflected in the gauge. Pass nil to detach.
// SetMetrics attaches a metrics collector. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.peersLock.Lock()
e.accountsLock.Lock()
e.metrics = m
e.peersLock.Unlock()
e.accountsLock.Unlock()
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
// LoadInitialPeers schedules the post-restart catch-up query for a
// random moment 8-10 minutes from now. Returns immediately. The
// catch-up populates the per-account tracker from the database so any
// peers that disconnected before the restart still get cleaned up.
//
// The random delay is critical: without it, every management replica
// hitting the same Postgres instance after a deploy would issue the
// catch-up query simultaneously.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers(ctx)
if e.headPeer != nil {
e.timer = time.AfterFunc(e.lifeTime, func() {
e.cleanup(ctx)
})
}
}
// Stop timer
func (e *EphemeralManager) Stop() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.timer != nil {
e.timer.Stop()
}
}
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.bgCtx = ctx
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.removePeer(peer.ID) {
e.metrics.DecPending(1)
}
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
e.timer.Stop()
e.timer = nil
}
delay := e.initialLoadDelay()
log.WithContext(ctx).Infof("ephemeral peer initial load scheduled in %s", delay)
e.initialLoadTimer = time.AfterFunc(delay, func() {
e.loadInitialAccounts(e.bgCtx)
})
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the EphemeralLifeTime period.
// Stop cancels the deferred initial load and any per-account timers.
func (e *EphemeralManager) Stop() {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
e.stopped = true
if e.initialLoadTimer != nil {
e.initialLoadTimer.Stop()
e.initialLoadTimer = nil
}
for _, entry := range e.accounts {
if entry.timer != nil {
entry.timer.Stop()
}
}
e.accounts = make(map[string]*accountEntry)
}
// OnPeerConnected is a no-op in the account-scoped design. The sweep
// query filters out connected peers at the database level, so we don't
// need an explicit "remove from list" signal when a peer reconnects.
// Kept on the interface to preserve the existing call sites.
func (e *EphemeralManager) OnPeerConnected(_ context.Context, _ *nbpeer.Peer) {
}
// OnPeerDisconnected registers a disconnect for the peer's account and
// arms a sweep if one isn't already scheduled. Non-ephemeral peers are
// ignored.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.isPeerOnList(peer.ID) {
return
}
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
e.metrics.IncPending()
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
})
}
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := e.newDeadLine()
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
e.metrics.AddPending(int64(len(peers)))
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock()
now := timeNow()
for p := e.headPeer; p != nil; p = p.next {
if now.Before(p.deadline) {
break
}
deletePeers[p.id] = p
e.headPeer = p.next
if p.next == nil {
e.tailPeer = nil
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
if e.headPeer != nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
entry, existed := e.accounts[peer.AccountID]
if !existed {
entry = &accountEntry{}
e.accounts[peer.AccountID] = entry
e.metrics.IncPending()
}
entry.lastDisconnectedAt = now
if entry.timer == nil {
delay := e.lifeTime + e.cleanupWindow
log.WithContext(ctx).Tracef("ephemeral: scheduling sweep for account %s in %s", peer.AccountID, delay)
accountID := peer.AccountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(e.bgCtxOrFallback(ctx), accountID)
})
} else {
e.timer = nil
}
}
// bgCtxOrFallback returns the long-lived background context captured at
// LoadInitialPeers time, falling back to the supplied ctx when the
// manager hasn't been started through LoadInitialPeers (e.g. in tests
// that drive the manager directly). Must be called with the lock held
// or before the timer is armed.
func (e *EphemeralManager) bgCtxOrFallback(ctx context.Context) context.Context {
if e.bgCtx != nil {
return e.bgCtx
}
return ctx
}
// loadInitialAccounts runs the post-restart catch-up query and seeds
// the tracker with one entry per account that has at least one
// disconnected ephemeral peer.
func (e *EphemeralManager) loadInitialAccounts(ctx context.Context) {
accounts, err := e.store.GetEphemeralAccountsLastDisconnect(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to load ephemeral accounts on startup: %v", err)
return
}
e.peersLock.Unlock()
now := timeNow()
added := 0
// Drop the gauge by the number of entries we just took off the list,
// regardless of whether the subsequent DeletePeers call succeeds. The
// list invariant is what the gauge tracks; failed delete batches are
// counted separately via CountCleanupError so we can still see them.
if len(deletePeers) > 0 {
e.metrics.CountCleanupRun()
e.metrics.DecPending(int64(len(deletePeers)))
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
e.metrics.CountCleanupError()
for accountID, lastDisc := range accounts {
// If we already learned about this account via an
// OnPeerDisconnected that arrived during the random delay
// window, prefer the live timestamp.
if _, alreadyTracked := e.accounts[accountID]; alreadyTracked {
continue
}
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
}
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: peerID,
accountID: accountID,
deadline: deadline,
}
entry := &accountEntry{lastDisconnectedAt: lastDisc}
horizon := lastDisc.Add(e.lifeTime)
if e.headPeer == nil {
e.headPeer = ep
}
if e.tailPeer != nil {
e.tailPeer.next = ep
}
e.tailPeer = ep
}
// removePeer drops the entry from the linked list. Returns true if a
// matching entry was found and removed so callers can keep the pending
// metric gauge in sync.
func (e *EphemeralManager) removePeer(id string) bool {
if e.headPeer == nil {
return false
}
if e.headPeer.id == id {
e.headPeer = e.headPeer.next
if e.tailPeer.id == id {
e.tailPeer = nil
var delay time.Duration
if horizon.After(now) {
delay = horizon.Sub(now) + e.cleanupWindow
} else {
// Already past the staleness window — sweep right away
// (one cleanupWindow later, to keep startup load smooth
// when many accounts qualify at once).
delay = e.cleanupWindow
}
return true
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
e.accounts[accountID] = entry
added++
}
for p := e.headPeer; p.next != nil; p = p.next {
if p.next.id == id {
// if we remove the last element from the chain then set the last-1 as tail
if e.tailPeer.id == id {
e.tailPeer = p
}
p.next = p.next.next
return true
e.metrics.AddPending(int64(added))
log.WithContext(ctx).Debugf("ephemeral: loaded %d account(s) for cleanup tracking", added)
}
// sweep runs the cleanup pass for a single account. It queries the
// database for disconnected ephemeral peers that have crossed the
// staleness window, deletes them via peers.Manager, and then decides
// whether to drop the account from the tracker or re-arm the timer.
func (e *EphemeralManager) sweep(ctx context.Context, accountID string) {
now := timeNow()
e.accountsLock.Lock()
entry, ok := e.accounts[accountID]
if !ok || e.stopped {
e.accountsLock.Unlock()
return
}
lastDisc := entry.lastDisconnectedAt
entry.timer = nil
e.accountsLock.Unlock()
threshold := now.Add(-e.lifeTime)
stalePeerIDs, err := e.store.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, threshold)
if err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to query stale peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
if len(stalePeerIDs) > 0 {
log.WithContext(ctx).Tracef("ephemeral: deleting %d peer(s) for account %s", len(stalePeerIDs), accountID)
if err := e.peersManager.DeletePeers(ctx, accountID, stalePeerIDs, activity.SystemInitiator, true); err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to delete peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
e.metrics.CountCleanupRun()
e.metrics.CountPeersCleaned(int64(len(stalePeerIDs)))
}
return false
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
entry, ok = e.accounts[accountID]
if !ok {
return
}
// Drop rule: if every disconnect we've observed has now crossed
// the staleness window, the sweep we just ran saw everything that
// could possibly need cleaning. Dropping is safe — a future
// disconnect will recreate the entry. The check uses the latest
// lastDisc, which may have advanced (concurrently with the sweep
// itself) due to a new OnPeerDisconnected, in which case we
// correctly re-arm.
horizon := entry.lastDisconnectedAt.Add(e.lifeTime)
if !horizon.After(now) {
delete(e.accounts, accountID)
e.metrics.DecPending(1)
log.WithContext(ctx).Tracef("ephemeral: dropping account %s (lastDisc=%s, horizon=%s, now=%s)",
accountID, lastDisc, horizon, now)
return
}
delay := horizon.Sub(now) + e.cleanupWindow
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) isPeerOnList(id string) bool {
for p := e.headPeer; p != nil; p = p.next {
if p.id == id {
return true
}
// rearm reschedules a sweep `delay` from now. Used after a recoverable
// error in the sweep path so the account doesn't get stuck.
func (e *EphemeralManager) rearm(ctx context.Context, accountID string, delay time.Duration) {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
return false
entry, ok := e.accounts[accountID]
if !ok {
return
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) newDeadLine() time.Time {
return timeNow().Add(e.lifeTime)
// defaultInitialLoadDelay returns a random duration in
// [initialLoadMinDelay, initialLoadMaxDelay). Process-wide
// math/rand is acceptable here — the delay is purely a smoothing
// jitter, not a security primitive.
func defaultInitialLoadDelay() time.Duration {
span := int64(initialLoadMaxDelay - initialLoadMinDelay)
if span <= 0 {
return initialLoadMinDelay
}
return initialLoadMinDelay + time.Duration(rand.Int63n(span))
}


@@ -2,299 +2,544 @@ package manager
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// MockStore is a thin in-memory stand-in that implements only the two
// methods the EphemeralManager uses. It honors the account / ephemeral
// / connected / lastSeen attributes of each peer so the cleanup logic
// can be exercised end-to-end without bringing up sqlite or Postgres.
type MockStore struct {
store.Store
mu sync.Mutex
account *types.Account
}
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
for _, v := range s.account.Peers {
if v.Ephemeral {
peers = append(peers, v)
func (s *MockStore) GetStaleEphemeralPeerIDsForAccount(_ context.Context, accountID string, olderThan time.Time) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.account == nil || s.account.Id != accountID {
return nil, nil
}
var ids []string
for _, p := range s.account.Peers {
if !p.Ephemeral {
continue
}
if p.Status == nil || p.Status.Connected {
continue
}
if p.Status.LastSeen.Before(olderThan) {
ids = append(ids, p.ID)
}
}
return peers, nil
return ids, nil
}
type MockAccountManager struct {
mu sync.Mutex
nbAccount.Manager
store *MockStore
deletePeerCalls int
bufferUpdateCalls map[string]int
wg *sync.WaitGroup
}
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
a.mu.Lock()
defer a.mu.Unlock()
a.deletePeerCalls++
delete(a.store.account.Peers, peerID)
if a.wg != nil {
a.wg.Done()
func (s *MockStore) GetEphemeralAccountsLastDisconnect(_ context.Context) (map[string]time.Time, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := map[string]time.Time{}
if s.account == nil {
return out, nil
}
return nil
}
func (a *MockAccountManager) GetDeletePeerCalls() int {
a.mu.Lock()
defer a.mu.Unlock()
return a.deletePeerCalls
}
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
a.bufferUpdateCalls = make(map[string]int)
var latest time.Time
hasAny := false
for _, p := range s.account.Peers {
if !p.Ephemeral || p.Status == nil || p.Status.Connected {
continue
}
if !hasAny || p.Status.LastSeen.After(latest) {
latest = p.Status.LastSeen
hasAny = true
}
}
a.bufferUpdateCalls[accountID]++
}
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
return 0
if hasAny {
out[s.account.Id] = latest
}
return a.bufferUpdateCalls[accountID]
return out, nil
}
func (a *MockAccountManager) GetStore() store.Store {
return a.store
}
func TestNewManager(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
// withFakeClock pins timeNow to a settable value for the duration of t.
// Returns a getter and a setter so subtests can advance virtual time.
func withFakeClock(t *testing.T, start time.Time) (get func() time.Time, set func(time.Time)) {
t.Helper()
var mu sync.Mutex
now := start
timeNow = func() time.Time {
return startTime
mu.Lock()
defer mu.Unlock()
return now
}
t.Cleanup(func() { timeNow = time.Now })
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
}
return func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}, func(v time.Time) {
mu.Lock()
defer mu.Unlock()
now = v
}
}
func TestNewManagerPeerConnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
// newManagerForTest builds a manager with short timers and no random
// initial-load delay so tests run instantly.
func newManagerForTest(t *testing.T, st store.Store, peersMgr peers.Manager) *EphemeralManager {
t.Helper()
mgr := NewEphemeralManager(st, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
mgr.initialLoadDelay = func() time.Duration { return 0 }
t.Cleanup(mgr.Stop)
return mgr
}
func TestNewManagerPeerDisconnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
// TestOnPeerDisconnected_RegistersAndSweeps drives the OnPeerDisconnected
// path with a fake clock: a single ephemeral peer disconnects, we
// advance past the staleness window, and the sweep deletes it.
func TestOnPeerDisconnected_RegistersAndSweeps(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
}
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
const (
ephemeralPeers = 10
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
peersMgr := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
var deletedMu sync.Mutex
var deleted []string
var deleteCalls atomic.Int32
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, accountID string, peerIDs []string, _ string, _ bool) error {
deleteCalls.Add(1)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
mockStore.mu.Unlock()
deletedMu.Lock()
deleted = append(deleted, peerIDs...)
deletedMu.Unlock()
return nil
}).
Times(1)
}).AnyTimes()
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
mgr := newManagerForTest(t, mockStore, peersMgr)
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
// One ephemeral peer that disconnected "now".
now := getNow()
p := &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
// Advance past lifeTime + cleanupWindow so the timer-driven sweep fires.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool { return deleteCalls.Load() >= 1 }, 2*time.Second, 5*time.Millisecond,
"sweep should fire and delete the stale peer")
deletedMu.Lock()
deletedCopy := append([]string(nil), deleted...)
deletedMu.Unlock()
require.Equal(t, []string{"p1"}, deletedCopy, "only the one ephemeral peer should be deleted")
}
// TestOnPeerDisconnected_NonEphemeralIgnored: a non-ephemeral disconnect
// must not register the account or arm any timer.
func TestOnPeerDisconnected_NonEphemeralIgnored(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// No DeletePeers expectation — must not be called.
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: false,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "non-ephemeral disconnect must not register an account")
mgr.accountsLock.Unlock()
}
// TestSweep_DropsAccountWhenIdle: after a sweep cleans the stale peers,
// if no more disconnects have arrived the account must be dropped from
// the in-memory tracker.
func TestSweep_DropsAccountWhenIdle(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should be dropped after sweep with no new disconnects")
}
// TestSweep_ReArmsWhenNewDisconnectArrived: simulate the race where a
// fresh disconnect arrives just before the sweep fires. The sweep must
// observe the updated lastDisc and re-arm rather than drop.
func TestSweep_ReArmsWhenNewDisconnectArrived(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p1 := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p1.ID] = p1
mgr.OnPeerDisconnected(context.Background(), p1)
// Advance most of the way toward the first sweep, then introduce
// a fresh disconnect that resets lastDisc.
setNow(now.Add(mgr.lifeTime - 10*time.Millisecond))
p2 := &nbpeer.Peer{ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p2.ID] = p2
mgr.OnPeerDisconnected(context.Background(), p2)
// Push past p1's staleness so the first sweep runs and cleans p1
// but observes p2 already on the account entry. It must re-arm.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, present := mockStore.account.Peers["p1"]
return !present
}, 2*time.Second, 5*time.Millisecond, "p1 should be cleaned at the first sweep")
// The account should still be tracked because p2 is younger than lifeTime
// from the sweep's vantage point at this moment.
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account should remain tracked because p2's disconnect kept it active")
// Push past p2's staleness; second sweep cleans p2 and drops the account.
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should drop after the final sweep")
}
// TestSweep_BatchesPeersPerAccount: many ephemeral peers disconnect on
// the same account; a single sweep must delete them all in one
// DeletePeers call.
func TestSweep_BatchesPeersPerAccount(t *testing.T) {
const ephemeralPeers = 8
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
deleteBatches := make(chan []string, 4)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
cp := append([]string(nil), peerIDs...)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
deleteBatches <- cp
return nil
}).Times(1)
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
for i := 0; i < ephemeralPeers; i++ {
id := fmt.Sprintf("p-%d", i)
// Stagger by a fraction of cleanupWindow so they all fall on
// the same sweep tick.
when := now.Add(time.Duration(i) * time.Millisecond)
p := &nbpeer.Peer{ID: id, AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: when}}
mockStore.account.Peers[id] = p
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
select {
case batch := <-deleteBatches:
require.Len(t, batch, ephemeralPeers, "all peers should be deleted in a single batch")
case <-time.After(2 * time.Second):
t.Fatal("expected one batched DeletePeers call")
}
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
// TestLoadInitialAccounts_SeedsFromStore exercises the post-restart
// catch-up path: pre-populate the store, point the manager at it, and
// confirm both already-stale and not-yet-stale peers get cleaned at
// their proper times.
func TestLoadInitialAccounts_SeedsFromStore(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: false,
}
store.account.Peers[p.ID] = p
now := getNow()
// p-stale: already past the staleness window when load runs.
mockStore.account.Peers["p-stale"] = &nbpeer.Peer{
ID: "p-stale", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now.Add(-time.Hour)},
}
// p-fresh: disconnected but not yet stale.
mockStore.account.Peers["p-fresh"] = &nbpeer.Peer{
ID: "p-fresh", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: true,
}
store.account.Peers[p.ID] = p
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
// Drive loadInitialAccounts directly with the fake-clock-aware now.
mgr.loadInitialAccounts(context.Background())
// The first sweep should fire within about one cleanupWindow and remove the stale peer.
setNow(now.Add(5 * mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, present := mockStore.account.Peers["p-stale"]
return !present
}, 2*time.Second, 5*time.Millisecond, "p-stale should be deleted on the first sweep")
// p-fresh is not yet stale; advance past its window.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, present := mockStore.account.Peers["p-fresh"]
return !present
}, 2*time.Second, 5*time.Millisecond, "p-fresh should be deleted once it crosses the staleness window")
}
// TestStop_CancelsPendingWork verifies that Stop() cancels both the
// deferred initial load and per-account sweep timers and that
// subsequent OnPeerDisconnected calls are ignored.
func TestStop_CancelsPendingWork(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// DeletePeers must NOT be called after Stop.
mgr := NewEphemeralManager(mockStore, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
// Use a long delay so the initial-load timer is still pending.
mgr.initialLoadDelay = func() time.Duration { return time.Hour }
mgr.LoadInitialPeers(context.Background())
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.NotNil(t, mgr.initialLoadTimer, "initial-load timer should be armed")
require.Len(t, mgr.accounts, 1, "account should be tracked after disconnect")
mgr.accountsLock.Unlock()
mgr.Stop()
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "Stop should clear tracked accounts")
require.True(t, mgr.stopped, "stopped flag must be set")
mgr.accountsLock.Unlock()
// Post-stop disconnect must be ignored.
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "disconnects after Stop must be ignored")
mgr.accountsLock.Unlock()
}
// TestOnPeerConnected_IsNoop: the OnPeerConnected hook is preserved on
// the interface but does nothing in the per-account model — the sweep
// query filters connected peers at the DB level.
func TestOnPeerConnected_IsNoop(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "disconnect should track the account")
mgr.accountsLock.Unlock()
mgr.OnPeerConnected(context.Background(), &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "OnPeerConnected must be a no-op")
mgr.accountsLock.Unlock()
}
// TestSweep_StoreErrorReArms: if the stale-peer query fails, the
// account must remain tracked and a follow-up sweep gets scheduled.
func TestSweep_StoreErrorReArms(t *testing.T) {
mockStore := &erroringStore{
MockStore: MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)},
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
mockStore.fail.Store(true)
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait until the failing sweep has run at least once.
require.Eventually(t, func() bool { return mockStore.failedCalls.Load() >= 1 },
2*time.Second, 5*time.Millisecond, "expected at least one failing sweep")
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account must remain tracked after a sweep error")
// Recover and ensure the rearmed sweep cleans up.
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mockStore.fail.Store(false)
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, present := mockStore.account.Peers["p1"]
return !present
}, 2*time.Second, 5*time.Millisecond, "rearmed sweep should clean up after the store recovers")
}
// erroringStore is a MockStore that can be flipped into a failing mode
// to exercise the sweep's error-rearm path.
type erroringStore struct {
MockStore
fail atomic.Bool
failedCalls atomic.Int32
}
func (s *erroringStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
if s.fail.Load() {
s.failedCalls.Add(1)
return nil, errors.New("synthetic store error")
}
return s.MockStore.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
}
// TestDefaultInitialLoadDelay confirms the jitter falls inside the
// documented [8m, 10m) range — sanity check for the production timer.
func TestDefaultInitialLoadDelay(t *testing.T) {
for i := 0; i < 1000; i++ {
d := defaultInitialLoadDelay()
assert.GreaterOrEqual(t, d, initialLoadMinDelay)
assert.Less(t, d, initialLoadMaxDelay)
}
}
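The test above only pins the range of the jitter. A minimal sketch of how such a delay could be produced, assuming the bounds are 8 and 10 minutes as the commit message describes. The constant names mirror the test, but `jitteredInitialLoadDelay` and its body are hypothetical and not taken from this commit.

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// Assumed values: the commit message describes a random delay
// between 8 and 10 minutes.
const (
	initialLoadMinDelay = 8 * time.Minute
	initialLoadMaxDelay = 10 * time.Minute
)

// jitteredInitialLoadDelay returns a duration in the half-open range
// [initialLoadMinDelay, initialLoadMaxDelay). Hypothetical stand-in for
// defaultInitialLoadDelay.
func jitteredInitialLoadDelay() time.Duration {
	span := int64(initialLoadMaxDelay - initialLoadMinDelay)
	return initialLoadMinDelay + time.Duration(rand.Int63n(span))
}

func main() {
	// Each replica would pick its own delay, spreading the catch-up queries.
	fmt.Println(jitteredInitialLoadDelay())
}
```

A half-open range matches the `GreaterOrEqual` / `Less` pair asserted in the test.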
@@ -351,3 +596,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
}
return acc
}
// Keep the import of "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
// referenced so it is not flagged as unused; production paths still rely on ephemeral.EphemeralLifeTime.
var _ = ephemeral.EphemeralLifeTime

@@ -5,7 +5,6 @@ package peers
import (
"context"
"fmt"
"net"
"time"
"github.com/rs/xid"
@@ -36,14 +35,6 @@ type Manager interface {
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
// GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP.
// Returns nil with an error when no match exists. No permission check;
// callers (the proxy's ValidateTunnelPeer RPC) are trusted server components.
GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error)
// GetPeerWithGroups returns the peer and the list of *types.Group it belongs
// to. Used by the proxy's auth path to authorise a request by the calling
// peer's group memberships.
GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error)
}
type managerImpl struct {
@@ -75,7 +66,7 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -88,7 +79,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -108,26 +99,6 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string,
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}
// GetPeerByTunnelIP delegates to the store's indexed lookup.
func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip)
}
// GetPeerWithGroups returns the peer plus its group memberships. Any store
// error returns (nil, nil, err) so callers never receive a valid peer
// alongside a non-nil error.
func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, nil, err
}
return p, groups, nil
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {

@@ -6,7 +6,6 @@ package peers
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -14,7 +13,6 @@ import (
account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
types "github.com/netbirdio/netbird/management/server/types"
)
// MockManager is a mock of Manager interface.
@@ -40,20 +38,6 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}
// DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper()
@@ -113,21 +97,6 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeerByTunnelIP mocks base method.
func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP.
func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip)
}
// GetPeerID mocks base method.
func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) {
m.ctrl.T.Helper()
@@ -143,22 +112,6 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey)
}
// GetPeerWithGroups mocks base method.
func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].([]*types.Group)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPeerWithGroups indicates an expected call of GetPeerWithGroups.
func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -209,3 +162,17 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

@@ -63,7 +63,7 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}

@@ -23,8 +23,6 @@ type Domain struct {
// SupportsCrowdSec is populated at query time from proxy cluster capabilities.
// Not persisted.
SupportsCrowdSec *bool `gorm:"-"`
// SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted.
SupportsPrivate *bool `gorm:"-"`
}
// EventMeta returns activity event metadata for a domain

@@ -49,7 +49,6 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain,
SupportsCrowdsec: d.SupportsCrowdSec,
SupportsPrivate: d.SupportsPrivate,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster

@@ -35,7 +35,6 @@ type proxyManager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -57,7 +56,7 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -94,7 +93,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster)
d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster)
ret = append(ret, d)
}
@@ -111,7 +109,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
if d.TargetCluster != "" {
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster)
cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.
@@ -122,7 +119,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -163,7 +160,7 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -187,7 +184,7 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,

@@ -10,7 +10,7 @@ import (
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
@@ -40,10 +40,6 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil
}
func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
@@ -155,3 +151,4 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

@@ -19,7 +19,6 @@ type Manager interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)

@@ -17,11 +17,10 @@ type store interface {
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -138,11 +137,6 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr)
}
// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported).
func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
return m.store.GetClusterSupportsPrivate(ctx, clusterAddr)
}
// CleanupStale removes proxies that haven't sent a heartbeat within the specified duration
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
@@ -184,3 +178,4 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco
}
return nil
}

@@ -15,16 +15,16 @@ import (
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
}
return nil, nil
}
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
@@ -99,9 +99,6 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")

@@ -92,20 +92,6 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr)
}
// ClusterSupportsPrivate mocks base method.
func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate.
func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()

@@ -20,9 +20,6 @@ type Capabilities struct {
RequireSubdomain *bool
// SupportsCrowdsec indicates whether this proxy has CrowdSec configured.
SupportsCrowdsec *bool
// Private indicates whether this proxy supports inbound access via a WireGuard
// tunnel and NetBird-only authentication policies.
Private *bool
}
// Proxy represents a reverse proxy instance
@@ -45,34 +42,10 @@ func (Proxy) TableName() string {
return "proxies"
}
// ClusterType is the source of a proxy cluster.
type ClusterType string
const (
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
// at least one proxy row in the cluster carries a non-NULL account_id.
ClusterTypeAccount ClusterType = "account"
// ClusterTypeShared is a cluster operated by NetBird and shared across
// accounts — all proxy rows in the cluster have account_id IS NULL.
ClusterTypeShared ClusterType = "shared"
)
// Cluster represents a group of proxy nodes serving the same address.
//
// Online and ConnectedProxies derive from the same 2-min active window
// the rest of the module uses, but Cluster rows are not gated on it —
// the cluster listing surfaces offline clusters too so operators can
// see and clean them up. The 1-hour heartbeat reaper still bounds the
// table eventually.
type Cluster struct {
ID string
Address string
Type ClusterType
Online bool
ConnectedProxies int
// *bool: nil = no proxy reported the capability; the dashboard renders that as unknown.
SupportsCustomPorts *bool
RequireSubdomain *bool
SupportsCrowdSec *bool
Private *bool
SelfHosted bool
}
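The `*bool` capability fields above are tri-state: nil means no proxy reported the capability, which the dashboard renders as unknown. A minimal sketch of how per-proxy reports could fold into one cluster-level value, assuming that any proxy claiming the capability wins (in line with the ClusterSupportsPrivate comment, "whether any active proxy claims the private capability"). `mergeCapability` and `render` are hypothetical helpers, not functions from this repository.

```go
package main

import "fmt"

// mergeCapability folds per-proxy reports into a cluster-level tri-state value.
// Hypothetical helper: nil = unreported, true = at least one proxy claims it,
// false = reported, but never claimed.
func mergeCapability(reports []*bool) *bool {
	var merged *bool
	for _, r := range reports {
		if r == nil {
			continue // this proxy did not report the capability
		}
		if *r {
			yes := true
			return &yes // any positive claim wins
		}
		no := false
		merged = &no // at least one proxy explicitly reported false
	}
	return merged
}

// render mirrors how a dashboard might display the tri-state value.
func render(b *bool) string {
	switch {
	case b == nil:
		return "unknown"
	case *b:
		return "yes"
	default:
		return "no"
	}
}

func main() {
	yes, no := true, false
	fmt.Println(render(mergeCapability(nil)))
	fmt.Println(render(mergeCapability([]*bool{nil, &no})))
	fmt.Println(render(mergeCapability([]*bool{&no, nil, &yes})))
}
```

Using a pointer rather than a plain bool keeps "not reported" distinct from "reported as false", which a zero-valued bool cannot express.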

@@ -37,7 +37,7 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
return
}
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -76,13 +76,13 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.store.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
@@ -92,7 +92,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
return
}
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -102,7 +102,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(ctx, store.LockingStrengthNone, userAuth.AccountId)
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
@@ -113,7 +113,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
@@ -123,7 +123,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -139,7 +139,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
token, err := h.store.GetProxyAccessTokenByID(ctx, store.LockingStrengthNone, tokenID)
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
@@ -154,12 +154,12 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.store.RevokeProxyAccessToken(ctx, tokenID); err != nil {
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(ctx, w, util.EmptyObject{})
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {


@@ -47,7 +47,7 @@ func TestCreateToken_AccountScoped(t *testing.T) {
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
@@ -90,7 +90,7 @@ func TestCreateToken_WithExpiration(t *testing.T) {
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
@@ -115,7 +115,7 @@ func TestCreateToken_EmptyName(t *testing.T) {
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
@@ -135,7 +135,7 @@ func TestCreateToken_PermissionDenied(t *testing.T) {
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
@@ -164,7 +164,7 @@ func TestListTokens(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
@@ -202,7 +202,7 @@ func TestRevokeToken_Success(t *testing.T) {
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
@@ -231,7 +231,7 @@ func TestRevokeToken_WrongAccount(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
@@ -258,7 +258,7 @@ func TestRevokeToken_ManagementWideToken(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,


@@ -9,7 +9,7 @@ import (
)
type Manager interface {
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)


@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -122,6 +122,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
@@ -137,19 +152,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetClusters mocks base method.
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetClusters indicates an expected call of GetClusters.
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
@@ -182,21 +197,6 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()


@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
return
}
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -196,15 +196,10 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
Type: api.ProxyClusterType(c.Type),
Online: c.Online,
ConnectedProxies: c.ConnectedProxies,
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdSec,
Private: c.Private,
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}


@@ -81,8 +81,6 @@ type ClusterDeriver interface {
type CapabilityProvider interface {
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
}
type Manager struct {
@@ -114,13 +112,9 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
// GetClusters returns every proxy cluster visible to the account
// (shared + its own BYOP), regardless of whether any proxy in the
// cluster is currently heartbeating. Each cluster is enriched with the
// capability flags reported by its active proxies so the dashboard can
// render feature support without a second round-trip.
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -128,25 +122,13 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
return nil, status.NewPermissionDeniedError()
}
clusters, err := m.store.GetProxyClusters(ctx, accountID)
if err != nil {
return nil, err
}
for i := range clusters {
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address)
}
return clusters, nil
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -158,7 +140,7 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, c
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -210,9 +192,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
target.Host = resource.Domain
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
case service.TargetTypeCluster:
// Cluster targets carry the upstream address on target_id; the
// proxy resolves the destination at request time.
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
@@ -222,7 +201,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -243,7 +222,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -528,7 +507,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -784,10 +763,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeCluster:
if err := validateClusterTarget(target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
@@ -795,13 +770,6 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func validateClusterTarget(target *service.Target) error {
if !target.Options.DirectUpstream {
return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host)
}
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@@ -836,7 +804,7 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -876,7 +844,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -978,14 +946,12 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
return fmt.Errorf("failed to get services: %w", err)
}
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil


@@ -1172,7 +1172,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().
@@ -1344,66 +1344,3 @@ func TestValidateSubdomainRequirement(t *testing.T) {
})
}
}
func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
// No peer or resource lookups must be issued for cluster targets.
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Options: rpservice.TargetOptions{DirectUpstream: true},
},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups")
}
// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the
// store-side check that cluster targets must opt into the host-stack dial
// path. Without DirectUpstream the proxy would route this target through
// the embedded NetBird client and fail on every request.
func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
targets := []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "backend.lan",
},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
require.Error(t, err, "cluster target without direct_upstream must be rejected")
assert.ErrorContains(t, err, "direct upstream disabled")
}
func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mgr := &Manager{store: mockStore}
svc := &rpservice.Service{
ID: "svc-1",
AccountID: accountID,
Targets: []*rpservice.Target{
{
TargetId: "eu.proxy.netbird.io",
TargetType: rpservice.TargetTypeCluster,
Host: "127.0.0.1",
},
},
}
require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup")
assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target")
}


@@ -45,11 +45,10 @@ const (
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
TargetTypeCluster TargetType = "cluster"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
@@ -61,11 +60,6 @@ type TargetOptions struct {
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
// DirectUpstream bypasses the proxy's embedded NetBird client and dials
// the target via the proxy host's network stack. Useful for upstreams
// reachable without WireGuard (public APIs, LAN services, localhost
// sidecars). Default false.
DirectUpstream bool `json:"direct_upstream,omitempty"`
}
type Target struct {
@@ -73,7 +67,7 @@ type Target struct {
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port uint16 `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
@@ -206,10 +200,6 @@ type Service struct {
Mode string `gorm:"default:'http'"`
ListenPort uint16
PortAutoAssigned bool
// Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only.
Private bool
// AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO.
AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"`
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
@@ -309,12 +299,6 @@ func (s *Service) ToAPIResponse() *api.Service {
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
Private: &s.Private,
}
if len(s.AccessGroups) > 0 {
groups := append([]string(nil), s.AccessGroups...)
resp.AccessGroups = &groups
}
if s.ProxyCluster != "" {
@@ -324,7 +308,6 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp
}
// ToProtoMapping converts the service into the wire format the proxy consumes.
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := s.buildPathMappings()
@@ -366,7 +349,6 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
RewriteRedirects: s.RewriteRedirects,
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
Private: s.Private,
}
if r := restrictionsToProto(s.Restrictions); r != nil {
@@ -473,8 +455,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 &&
opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
@@ -496,22 +477,17 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if len(opts.CustomHeaders) > 0 {
apiOpts.CustomHeaders = &opts.CustomHeaders
}
if opts.DirectUpstream {
apiOpts.DirectUpstream = &opts.DirectUpstream
}
return apiOpts
}
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 &&
len(opts.CustomHeaders) == 0 && !opts.DirectUpstream {
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
return nil
}
popts := &proto.PathTargetOptions{
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
DirectUpstream: opts.DirectUpstream,
SkipTlsVerify: opts.SkipTLSVerify,
PathRewrite: pathRewriteToProto(opts.PathRewrite),
CustomHeaders: opts.CustomHeaders,
}
if opts.RequestTimeout != 0 {
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
@@ -561,9 +537,6 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
if o.CustomHeaders != nil {
opts.CustomHeaders = *o.CustomHeaders
}
if o.DirectUpstream != nil {
opts.DirectUpstream = *o.DirectUpstream
}
return opts, nil
}
@@ -578,14 +551,6 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if req.ListenPort != nil {
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
}
if req.Private != nil {
s.Private = *req.Private
}
if req.AccessGroups != nil {
s.AccessGroups = append([]string(nil), *req.AccessGroups...)
} else {
s.AccessGroups = nil
}
targets, err := targetsFromAPI(accountID, req.Targets)
if err != nil {
@@ -775,9 +740,6 @@ func (s *Service) Validate() error {
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
if err := s.validatePrivateRequirements(); err != nil {
return err
}
switch s.Mode {
case ModeHTTP:
@@ -791,23 +753,6 @@ func (s *Service) Validate() error {
}
}
// validatePrivateRequirements enforces the private-service contract: HTTP mode, ≥1 access group, no bearer auth.
func (s *Service) validatePrivateRequirements() error {
if !s.Private {
return nil
}
if s.Mode != "" && s.Mode != ModeHTTP {
return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode)
}
if len(s.AccessGroups) == 0 {
return errors.New("private services require at least one access group")
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive")
}
return nil
}
func (s *Service) validateHTTPMode() error {
if s.Domain == "" {
return errors.New("service domain is required")
@@ -854,21 +799,11 @@ func (s *Service) validateHTTPTargets() error {
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// Host is normally overwritten by replaceHostByLookup with the
// resolved peer IP / resource address; operator-supplied values
// are honored only when DirectUpstream is set. Validate the
// override here so misconfigured hosts fail fast at API time.
if err := validateDirectUpstreamHost(i, target); err != nil {
return err
}
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
case TargetTypeCluster:
if err := validateClusterTarget(i, target); err != nil {
return err
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
@@ -886,67 +821,25 @@ func (s *Service) validateHTTPTargets() error {
return nil
}
// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled.
func validateClusterTarget(idx int, target *Target) error {
host := strings.TrimSpace(target.Host)
if host == "" {
return fmt.Errorf("target %d: has empty host", idx)
}
if !target.Options.DirectUpstream {
return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host)
}
return validateDirectUpstreamHost(idx, target)
}
// validateDirectUpstreamHost validates the operator-supplied Host on a
// peer/host/domain target when DirectUpstream is set. Empty Host is
// allowed — the lookup fills in the default peer IP / resource address.
// Without DirectUpstream the Host value is silently overwritten by
// replaceHostByLookup, so we don't validate it (preserves the historical
// behaviour where APIs accepted any value and dropped it). Non-empty
// Host with DirectUpstream must look like a hostname or IP and must
// not carry a port (port lives on Target.Port).
func validateDirectUpstreamHost(idx int, target *Target) error {
if !target.Options.DirectUpstream {
return nil
}
host := strings.TrimSpace(target.Host)
if host == "" {
return nil
}
if strings.ContainsAny(host, " \t/") {
return fmt.Errorf("target %d: host %q contains invalid characters", idx, host)
}
if _, _, err := net.SplitHostPort(host); err == nil {
return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host)
}
return nil
}
func (s *Service) validateL4Target(target *Target) error {
// L4 services have a single target; per-target disable is meaningless
// (use the service-level Enabled flag instead). Force it on so that
// buildPathMappings always includes the target in the proto.
target.Enabled = true
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
if target.TargetType != TargetTypeCluster && target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
if err := validateDirectUpstreamHost(0, target); err != nil {
return err
}
// OK
case TargetTypeSubnet:
if target.Host == "" {
return errors.New("target host is required for subnet targets")
}
case TargetTypeCluster:
// target_id carries the cluster address; the proxy resolves
// the upstream at request time.
default:
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
}
@@ -1281,11 +1174,6 @@ func (s *Service) Copy() *Service {
}
}
var accessGroups []string
if len(s.AccessGroups) > 0 {
accessGroups = append([]string(nil), s.AccessGroups...)
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
@@ -1307,8 +1195,6 @@ func (s *Service) Copy() *Service {
Mode: s.Mode,
ListenPort: s.ListenPort,
PortAutoAssigned: s.PortAutoAssigned,
Private: s.Private,
AccessGroups: accessGroups,
}
}
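
The host rule in validateDirectUpstreamHost above can be tried in isolation. The following is a minimal sketch, not the repository's code: checkDirectUpstreamHost is a hypothetical standalone copy of the same three checks (empty allowed, no whitespace or slash, no host:port form), with the target index and the DirectUpstream gate left out.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// checkDirectUpstreamHost is a hypothetical helper mirroring the checks in the
// hunk above: an empty host is accepted, whitespace and '/' are rejected, and
// anything net.SplitHostPort can parse as host:port is rejected.
func checkDirectUpstreamHost(host string) error {
	host = strings.TrimSpace(host)
	if host == "" {
		return nil
	}
	if strings.ContainsAny(host, " \t/") {
		return fmt.Errorf("host %q contains invalid characters", host)
	}
	// SplitHostPort succeeds only when a port separator is present, so a nil
	// error means the caller put the port in the host field.
	if _, _, err := net.SplitHostPort(host); err == nil {
		return fmt.Errorf("host %q must not include a port", host)
	}
	return nil
}

func main() {
	samples := []string{"backend.lan", "10.0.0.5", "::1", "backend.lan:8080", "[::1]:443", "http://backend.lan"}
	for _, h := range samples {
		fmt.Printf("%q -> %v\n", h, checkDirectUpstreamHost(h))
	}
}
```

A bare IPv6 literal such as "::1" passes because SplitHostPort reports an error for it, while the bracketed "[::1]:443" form is rejected as carrying a port.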

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -1117,191 +1116,3 @@ func TestValidate_HeaderAuths(t *testing.T) {
assert.Contains(t, err.Error(), "exceeds maximum length")
})
}
func TestValidate_HTTPClusterTarget(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate")
}
func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id")
}
// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target
// rule that operator-supplied Host is mandatory: cluster targets dial the
// upstream via the host network stack (direct_upstream is implied), so an
// empty Host leaves the proxy with nothing to dial.
func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host")
}
// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second
// half of the cluster-target rule: DirectUpstream must be true so the
// stdlib transport branch in MultiTransport is taken. Without it the
// embedded NetBird client would try to dial the cluster address through
// the WG tunnel, which is the wrong network for a cluster upstream.
func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
rp := validProxy()
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {
svc := validProxy()
svc.Private = true
svc.AccessGroups = []string{"grp-admins", "grp-ops"}
cp := svc.Copy()
require.NotNil(t, cp)
assert.True(t, cp.Private)
assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups)
cp.Private = false
assert.True(t, svc.Private)
cp.AccessGroups[0] = "grp-other"
assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups)
}
func TestService_APIRoundtrip_Private(t *testing.T) {
enabled := true
private := true
accessGroups := []string{"grp-admins"}
targets := []api.ServiceTarget{{
TargetId: "eu.proxy.netbird.io",
TargetType: api.ServiceTargetTargetType("cluster"),
Protocol: "http",
Port: 80,
Enabled: true,
}}
req := &api.ServiceRequest{
Name: "svc-private",
Domain: "myapp.eu.proxy.netbird.io",
Enabled: enabled,
Private: &private,
AccessGroups: &accessGroups,
Targets: &targets,
}
svc := &Service{}
require.NoError(t, svc.FromAPIRequest(req, "acc-1"))
assert.True(t, svc.Private)
assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups)
resp := svc.ToAPIResponse()
require.NotNil(t, resp.Private)
assert.True(t, *resp.Private)
require.NotNil(t, resp.AccessGroups)
assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups)
}
func TestValidate_Private_RequiresAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "access group")
}
func TestValidate_Private_RejectsBearerAuth(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-sso"},
}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "mutually exclusive")
}
func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "http",
Host: "backend.lan",
Options: TargetOptions{DirectUpstream: true},
Enabled: true,
}}
require.NoError(t, rp.Validate())
}
func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) {
rp := validProxy()
rp.Private = true
rp.AccessGroups = []string{"grp-admins"}
rp.Mode = ModeTCP
rp.Targets = []*Target{{
TargetId: "eu.proxy.netbird.io",
TargetType: TargetTypeCluster,
Protocol: "tcp",
Enabled: true,
}}
assert.ErrorContains(t, rp.Validate(), "HTTP")
}
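
TestService_Copy_RoundtripsPrivate above relies on Copy cloning the AccessGroups slice rather than sharing its backing array. A minimal sketch of that idiom follows; the service type and copy method are hypothetical stand-ins that keep only the two fields the test touches.

```go
package main

import "fmt"

// service is a hypothetical, trimmed-down stand-in for the Service type.
type service struct {
	Private      bool
	AccessGroups []string
}

// copy clones the slice with append onto a nil slice, so the copy gets its
// own backing array and edits to it cannot reach the original.
func (s *service) copy() *service {
	var groups []string
	if len(s.AccessGroups) > 0 {
		groups = append([]string(nil), s.AccessGroups...)
	}
	return &service{Private: s.Private, AccessGroups: groups}
}

func main() {
	orig := &service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
	cp := orig.copy()
	cp.AccessGroups[0] = "grp-other"

	fmt.Println(orig.AccessGroups[0]) // grp-admins
	fmt.Println(cp.AccessGroups[0])   // grp-other
}
```

Assigning the slice directly (AccessGroups: s.AccessGroups) would make the write through cp visible in orig, which is the case the test guards against.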

@@ -20,20 +20,6 @@ type KeyPair struct {
type Claims struct {
jwt.RegisteredClaims
Method auth.Method `json:"method"`
// Email is the calling user's email address. Carried so the
// proxy can stamp identity on upstream requests (e.g.
// x-litellm-end-user-id) without an extra management
// round-trip on every cookie-bearing request.
Email string `json:"email,omitempty"`
// Groups carries the user's group IDs so the proxy can stamp them
// onto upstream requests (X-NetBird-Groups) from the cookie path
// without an extra management round-trip.
Groups []string `json:"groups,omitempty"`
// GroupNames carries the human-readable display names for the ids
// in Groups, ordered identically (positional pairing). Slice may be
// shorter than Groups for tokens minted before names were
// resolvable; the consumer falls back to ids for missing positions.
GroupNames []string `json:"group_names,omitempty"`
}
func GenerateKeyPair() (*KeyPair, error) {
@@ -48,13 +34,7 @@ func GenerateKeyPair() (*KeyPair, error) {
}, nil
}
// SignToken mints a session JWT for the given user and domain. email,
// groups, and groupNames, when non-empty, are embedded so the proxy can
// authorise and stamp identity for policy-aware middlewares without a
// management round-trip on every cookie-bearing request. groupNames
// pairs positionally with groups; pass nil when names couldn't be
// resolved.
func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) {
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil {
return "", fmt.Errorf("decode private key: %w", err)
@@ -76,10 +56,7 @@ func SignToken(privKeyB64, userID, email, domain string, method auth.Method, gro
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
Method: method,
Email: email,
Groups: append([]string(nil), groups...),
GroupNames: append([]string(nil), groupNames...),
Method: method,
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
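
The Claims comments above describe GroupNames as pairing positionally with Groups, with the consumer falling back to ids where names are missing. A minimal sketch of what such a consumer could look like is below; groupLabels is a hypothetical helper, not a function from the repository.

```go
package main

import "fmt"

// groupLabels is a hypothetical consumer-side helper: it returns one label per
// group id, using names[i] when that position exists and the id otherwise.
func groupLabels(ids, names []string) []string {
	labels := make([]string, len(ids))
	for i, id := range ids {
		if i < len(names) {
			labels[i] = names[i]
			continue
		}
		// names is shorter than ids (e.g. a token minted before names were
		// resolvable), so fall back to the id for this position.
		labels[i] = id
	}
	return labels
}

func main() {
	ids := []string{"grp-1", "grp-2", "grp-3"}
	names := []string{"Admins"}
	fmt.Println(groupLabels(ids, names)) // [Admins grp-2 grp-3]
}
```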

@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -56,7 +56,7 @@ func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID str
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -103,7 +103,7 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -151,7 +151,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

@@ -79,7 +79,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
@@ -95,7 +95,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
@@ -112,7 +112,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
@@ -134,7 +134,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
@@ -150,7 +150,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
@@ -179,7 +179,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -212,7 +212,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -235,7 +235,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -261,7 +261,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -293,7 +293,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -319,7 +319,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -354,7 +354,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -394,7 +394,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -441,7 +441,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -471,7 +471,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -503,7 +503,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -529,7 +529,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, ctx, nil)
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
@@ -545,7 +545,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)

@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zone
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -56,7 +56,7 @@ func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID,
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -161,7 +161,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

@@ -80,7 +80,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
@@ -113,7 +113,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
@@ -135,7 +135,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
@@ -153,7 +153,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
@@ -181,7 +181,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -215,7 +215,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -244,7 +244,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -273,7 +273,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -297,7 +297,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -323,7 +323,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -349,7 +349,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -380,7 +380,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
@@ -445,7 +445,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, ctx, nil)
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -470,7 +470,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -500,7 +500,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, ctx, nil)
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -523,7 +523,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -549,7 +549,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, ctx, nil)
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
@@ -565,7 +565,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, ctx, nil)
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)

@@ -10,10 +10,8 @@ import (
"slices"
"time"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/cors"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -21,6 +19,7 @@ import (
"google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
@@ -28,20 +27,16 @@ import (
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/crypt"
)
const apiPrefix = "/api"
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -99,17 +94,12 @@ func (s *BaseServer) Store() store.Store {
func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store {
var err error
key := s.Config.DataStoreEncryptionKey
if key == "" {
log.Debugf("generate new activity store encryption key")
key, err = crypt.GenerateKey()
if err != nil {
log.Fatalf("failed to generate event store encryption key: %v", err)
}
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
if err != nil {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
@@ -120,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -128,22 +118,6 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
// the deployment isn't using the embedded variant.
func (s *BaseServer) IDPHandler() http.Handler {
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
if !ok || embeddedIdP == nil {
return nil
}
return cors.AllowAll().Handler(embeddedIdP.Handler())
}
func (s *BaseServer) Router() *mux.Router {
return Create(s, func() *mux.Router {
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
@@ -39,7 +38,7 @@ func (s *BaseServer) JobManager() *job.Manager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := validator.NewIntegratedValidator(
integratedPeerValidator, err := integrations.NewIntegratedValidator(
context.Background(),
s.PeersManager(),
s.SettingsManager(),

@@ -57,7 +57,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return permissions.NewManager(s.Store())
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
s.AfterInit(func(s *BaseServer) {
manager.SetAccountManager(s.AccountManager())
})
return manager
})
}
@@ -147,6 +153,7 @@ func (s *BaseServer) IdpManager() idp.Manager {
return idpManager
}
return nil
})
}
@@ -228,7 +235,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return &m
})
}
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
return false
}

@@ -34,8 +34,6 @@ const (
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
ContainerKeyBaseServer = "baseServer"
)
type Server interface {
@@ -93,7 +91,7 @@ type Config struct {
// NewServer initializes and configures a new Server instance
func NewServer(cfg *Config) *BaseServer {
s := &BaseServer{
return &BaseServer{
Config: cfg.NbConfig,
container: make(map[string]any),
dnsDomain: cfg.DNSDomain,
@@ -106,9 +104,6 @@ func NewServer(cfg *Config) *BaseServer {
mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
}
s.container[ContainerKeyBaseServer] = s
return s
}
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
@@ -193,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
@@ -304,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
@@ -323,8 +318,6 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
gRPCHandler.ServeHTTP(writer, request)
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(writer, request)
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
idpHandler.ServeHTTP(writer, request)
default:
httpHandler.ServeHTTP(writer, request)
}

@@ -6,11 +6,9 @@ import (
"net/netip"
"net/url"
"strings"
"time"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
@@ -187,38 +185,9 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
// settings == nil → field stays nil → "no info in this snapshot", client
// preserves the deadline it already had. settings non-nil → emit either a
// valid deadline or the explicit-zero "disabled" sentinel via
// encodeSessionExpiresAt.
if settings != nil {
response.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
}
return response
}
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
// representation used on LoginResponse, SyncResponse and
// ExtendAuthSessionResponse. See the proto comments on those messages.
//
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
// its anchor.
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
// UTC deadline.
//
// Returning nil ("no info, preserve client's anchor") is the caller's job —
// only meaningful on Sync builds where settings were not resolved.
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
if deadline.IsZero() {
return &timestamppb.Timestamp{}
}
return timestamppb.New(deadline)
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte

@@ -5,7 +5,6 @@ import (
"net/netip"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
@@ -201,29 +200,3 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
// "expiry disabled or peer is not SSO-tracked" sentinel.
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
//
// The third state (nil pointer = "no info in this snapshot") is the caller's
// responsibility on the Sync path when settings could not be resolved; the
// helper itself never returns nil.
func TestEncodeSessionExpiresAt(t *testing.T) {
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
got := encodeSessionExpiresAt(time.Time{})
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
assert.Equal(t, int64(0), got.GetSeconds())
assert.Equal(t, int32(0), got.GetNanos())
})
t.Run("non-zero deadline round-trips", func(t *testing.T) {
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
got := encodeSessionExpiresAt(deadline)
assert.NotNil(t, got)
assert.True(t, got.AsTime().Equal(deadline))
})
}
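
The three-state wire encoding pinned by the test above (nil, explicit zero, absolute deadline) may be easier to follow from the receiving side. The sketch below is illustrative only: `Timestamp` is a local stand-in for `timestamppb.Timestamp`, and `encode`, `applyDeadline` and `sampleDeadline` are hypothetical names, not the client's actual `applySessionDeadline`.

```go
package main

import (
	"fmt"
	"time"
)

// Timestamp is a local stand-in for timestamppb.Timestamp (assumption: only
// the seconds/nanos pair matters for this sketch).
type Timestamp struct {
	Seconds int64
	Nanos   int32
}

// sampleDeadline is an arbitrary non-zero deadline used for illustration.
var sampleDeadline = time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)

// encode mirrors the server-side helper: a zero deadline becomes the
// explicit-zero sentinel, anything else becomes an absolute timestamp.
// It never returns nil; nil is reserved for "no info in this snapshot".
func encode(deadline time.Time) *Timestamp {
	if deadline.IsZero() {
		return &Timestamp{}
	}
	return &Timestamp{Seconds: deadline.Unix(), Nanos: int32(deadline.Nanosecond())}
}

// applyDeadline is a hypothetical client-side decoder for the three states.
func applyDeadline(current time.Time, ts *Timestamp) time.Time {
	switch {
	case ts == nil:
		// No info in this snapshot: keep the deadline we already had.
		return current
	case ts.Seconds == 0 && ts.Nanos == 0:
		// Expiry disabled or peer not SSO-tracked: clear the deadline.
		return time.Time{}
	default:
		return time.Unix(ts.Seconds, int64(ts.Nanos)).UTC()
	}
}

func main() {
	fmt.Println(applyDeadline(sampleDeadline, nil).Equal(sampleDeadline))
	fmt.Println(applyDeadline(sampleDeadline, encode(time.Time{})).IsZero())
}
```

The `IsZero` check in `encode` is what keeps the sentinel exact: `time.Time{}.Unix()` is a large negative number, so a naive conversion would not produce seconds=0.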

@@ -9,7 +9,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -137,12 +136,9 @@ type proxyConnection struct {
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
// syncStream is set when the proxy connected via SyncMappings.
// When non-nil, the sender goroutine uses this instead of stream.
syncStream proto.ProxyService_SyncMappingsServer
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
@@ -210,323 +206,145 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
s.proxyController = proxyController
}
// proxyConnectParams holds the validated parameters extracted from either
// a GetMappingUpdateRequest or a SyncMappingsInit message.
type proxyConnectParams struct {
proxyID string
address string
capabilities *proto.ProxyCapabilities
}
// GetMappingUpdate handles the control stream with proxy clients
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
if err != nil {
return err
}
params.capabilities = req.GetCapabilities()
ctx := stream.Context()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
stream: stream,
})
if err != nil {
return err
}
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
}
// SyncMappings implements the bidirectional SyncMappings RPC.
// It mirrors GetMappingUpdate but provides application-level back-pressure:
// management waits for an ack from the proxy before sending the next batch.
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
init, err := recvSyncInit(stream)
if err != nil {
return err
}
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
if err != nil {
return err
}
params.capabilities = init.GetCapabilities()
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
syncStream: stream,
})
if err != nil {
return err
}
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
s.cleanupFailedSnapshot(stream.Context(), conn)
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.drainRecv(stream, errChan)
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
}
// recvSyncInit receives and validates the first message on a SyncMappings stream.
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
firstMsg, err := stream.Recv()
if err != nil {
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
}
init := firstMsg.GetInit()
if init == nil {
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
}
return init, nil
}
// validateProxyConnect validates the proxy ID and address, and checks cluster
// address availability for account-scoped tokens.
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
if proxyID == "" {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
if !isProxyAddressValid(address) {
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
if err != nil {
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
}
}
return proxyConnectParams{proxyID: proxyID, address: address}, nil
}
// registerProxyConnection creates a proxyConnection, registers it with the
// proxy manager and cluster, and stores it in connectedProxies. The caller
// provides a partially initialised connSeed with stream-specific fields set;
// the remaining fields are filled in here.
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
peerInfo := PeerIPFromContext(ctx)
log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId()
if proxyID == "" {
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
proxyAddress := req.GetAddress()
if !isProxyAddressValid(proxyAddress) {
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
var tokenID string
if token := GetProxyTokenFromContext(ctx); token != nil {
if token.AccountID != nil {
accountID = token.AccountID
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
s.supersedePriorConnection(params.proxyID, sessionID)
connCtx, cancel := context.WithCancel(ctx)
connSeed.proxyID = params.proxyID
connSeed.sessionID = sessionID
connSeed.address = params.address
connSeed.accountID = accountID
connSeed.tokenID = tokenID
connSeed.capabilities = params.capabilities
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
connSeed.ctx = connCtx
connSeed.cancel = cancel
var caps *proxy.Capabilities
if c := params.capabilities; c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
Private: c.Private,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
if err != nil {
cancel()
if accountID != nil {
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(params.proxyID, connSeed)
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
}
return connSeed, proxyRecord, nil
}
// supersedePriorConnection cancels any existing connection for the given proxy.
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": newSessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
}
// cleanupFailedSnapshot removes the connection from the cluster and store
// after a snapshot send failure.
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
}
}
// drainRecv consumes and discards messages from a bidirectional stream.
// The proxy sends an ack for every incremental update; we don't need them
// after the snapshot phase. Recv errors are forwarded to errChan.
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
for {
if _, err := stream.Recv(); err != nil {
errChan <- err
return
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
if err != nil {
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
}
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
// and wait for termination. When bidi is true, normal stream closure (EOF,
// canceled) is treated as a clean disconnect rather than an error.
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": conn.proxyID,
"session_id": conn.sessionID,
"address": conn.address,
"cluster_addr": conn.address,
"account_id": conn.accountID,
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
}
defer s.disconnectProxy(conn)
go s.heartbeat(conn.ctx, conn, proxyRecord)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, conn, proxyRecord)
select {
case err := <-errChan:
if bidi && isStreamClosed(err) {
log.Infof("Proxy %s stream closed", conn.proxyID)
return nil
}
log.Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
case <-conn.ctx.Done():
log.Infof("Proxy %s context canceled", conn.proxyID)
return conn.ctx.Err()
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err()
}
}
// disconnectProxy removes the connection from cluster and store, unless it
// has already been superseded by a newer connection.
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
conn.cancel()
return
}
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}
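
The "skip cleanup when superseded" rule in `disconnectProxy` rests on `sync.Map.CompareAndDelete` (Go 1.20+), which deletes an entry only if it still holds the exact value the caller passes in. A minimal sketch of that behaviour, using hypothetical `conn` and `registry` types rather than the real `proxyConnection`:

```go
package main

import (
	"fmt"
	"sync"
)

// conn is a hypothetical stand-in for proxyConnection.
type conn struct{ sessionID string }

// registry is a hypothetical wrapper around the connected-proxies map.
type registry struct{ conns sync.Map }

// register stores c as the current connection for id, replacing any prior one.
func (r *registry) register(id string, c *conn) { r.conns.Store(id, c) }

// release removes the entry for id only if it is still c. It reports whether
// the caller owned the entry and should therefore run the remaining cleanup.
func (r *registry) release(id string, c *conn) bool {
	return r.conns.CompareAndDelete(id, c)
}

// current returns the registered connection for id, or nil if there is none.
func (r *registry) current(id string) *conn {
	v, ok := r.conns.Load(id)
	if !ok {
		return nil
	}
	return v.(*conn)
}

func main() {
	r := &registry{}
	first := &conn{sessionID: "s1"}
	second := &conn{sessionID: "s2"}

	r.register("proxy-1", first)
	r.register("proxy-1", second) // a reconnect supersedes the first session

	fmt.Println(r.release("proxy-1", first))    // false: first no longer owns the entry
	fmt.Println(r.current("proxy-1").sessionID) // s2
	fmt.Println(r.release("proxy-1", second))   // true
}
```

Comparing by pointer identity means a late-running cleanup from an old session cannot unregister the connection that replaced it.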
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
// one batch, then waits for the proxy to ack before sending the next.
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
return err
}
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
for _, m := range mappings[i:end] {
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
if err != nil {
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
}
m.AuthToken = token
}
if err := stream.Send(&proto.SyncMappingsResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
if len(mappings) == 0 {
if err := stream.Send(&proto.SyncMappingsResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
if err := waitForAck(stream); err != nil {
return err
}
}
return nil
}
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive ack: %w", err)
}
if msg.GetAck() == nil {
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
}
return nil
}
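
The send-then-wait loop in `sendSnapshotSync` and `waitForAck` can be tried in isolation with an in-memory stream. This is a sketch under stated assumptions: `batch`, `ackStream`, `memStream` and `sendInBatches` are hypothetical stand-ins for the generated gRPC types, and token generation is left out.

```go
package main

import (
	"errors"
	"fmt"
)

// batch is a hypothetical stand-in for proto.SyncMappingsResponse.
type batch struct {
	Items    []string
	Complete bool
}

// ackStream is a hypothetical minimal view of the bidirectional stream.
type ackStream interface {
	Send(b batch) error
	RecvAck() error
}

// sendInBatches sends items in fixed-size batches and waits for an ack after
// each one, so at most one batch is in flight at any time.
func sendInBatches(s ackStream, items []string, size int) error {
	if size <= 0 {
		return fmt.Errorf("invalid batch size: %d", size)
	}
	if len(items) == 0 {
		// Still signal completion so the peer can leave its initial-sync state.
		if err := s.Send(batch{Complete: true}); err != nil {
			return fmt.Errorf("send completion: %w", err)
		}
		return s.RecvAck()
	}
	for i := 0; i < len(items); i += size {
		end := i + size
		if end > len(items) {
			end = len(items)
		}
		if err := s.Send(batch{Items: items[i:end], Complete: end == len(items)}); err != nil {
			return fmt.Errorf("send batch: %w", err)
		}
		if err := s.RecvAck(); err != nil {
			return fmt.Errorf("receive ack: %w", err)
		}
	}
	return nil
}

// memStream records what was sent and tracks how many batches are unacked.
type memStream struct {
	sent        []batch
	inFlight    int
	maxInFlight int
}

func (m *memStream) Send(b batch) error {
	m.sent = append(m.sent, b)
	m.inFlight++
	if m.inFlight > m.maxInFlight {
		m.maxInFlight = m.inFlight
	}
	return nil
}

func (m *memStream) RecvAck() error {
	if m.inFlight == 0 {
		return errors.New("ack without a batch in flight")
	}
	m.inFlight--
	return nil
}

func main() {
	m := &memStream{}
	if err := sendInBatches(m, []string{"a", "b", "c", "d", "e"}, 2); err != nil {
		fmt.Println("error:", err)
		return
	}
	// Number of batches, peak unacked batches, completion flag on the last batch.
	fmt.Println(len(m.sent), m.maxInFlight, m.sent[len(m.sent)-1].Complete)
}
```

Because the sender blocks on the ack, a slow receiver bounds the sender's pace at the application level instead of relying on transport buffers.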
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
@@ -563,9 +381,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
if s.snapshotBatchSize <= 0 {
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
}
mappings, err := s.snapshotServiceMappings(ctx, conn)
if err != nil {
@@ -645,26 +460,12 @@ func isProxyAddressValid(addr string) bool {
return err == nil
}
// isStreamClosed returns true for errors that indicate normal stream
// termination: io.EOF, context cancellation, or gRPC Canceled.
func isStreamClosed(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return true
}
return status.Code(err) == codes.Canceled
}
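
`isStreamClosed` depends on `errors.Is` seeing through wrapped errors, which matters because helpers such as `waitForAck` wrap `Recv` errors with `%w`. A reduced, stdlib-only sketch follows; it deliberately omits the gRPC `codes.Canceled` branch, and `isCleanClose`, `errAckEOF` and `errReset` are hypothetical names for illustration.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
)

// errAckEOF mimics a Recv error wrapped with %w when the peer hangs up.
var errAckEOF = fmt.Errorf("receive ack: %w", io.EOF)

// errReset stands for a genuine transport failure.
var errReset = errors.New("connection reset by peer")

// isCleanClose is a simplified stand-in for isStreamClosed without the gRPC
// status-code check: EOF and context cancellation count as normal closure.
func isCleanClose(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}

func main() {
	fmt.Println(isCleanClose(errAckEOF)) // wrapped EOF still counts as a clean close
	fmt.Println(isCleanClose(errReset))  // a real failure does not
}
```

Had the wrapper used `%v` instead of `%w`, the chain would be broken and the wrapped EOF would be reported as a failure.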
// sender handles sending messages to proxy.
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
// otherwise the legacy GetMappingUpdateResponse stream is used.
// sender handles sending messages to proxy
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
select {
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
if err := conn.stream.Send(resp); err != nil {
errChan <- err
return
}
@@ -674,17 +475,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
}
}
// sendResponse sends a mapping update on whichever stream the proxy connected with.
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
if conn.syncStream != nil {
return conn.syncStream.Send(&proto.SyncMappingsResponse{
Mapping: resp.Mapping,
InitialSyncComplete: resp.InitialSyncComplete,
})
}
return conn.stream.Send(resp)
}
// SendAccessLog processes access log from proxy
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
@@ -751,15 +541,10 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
// Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService).
connUpdate = filterMappingsForProxy(conn, connUpdate)
if connUpdate == nil || len(connUpdate.Mapping) == 0 {
return true
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
@@ -888,20 +673,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
}
// proxyAcceptsMapping returns whether the proxy can receive this mapping.
// Private mappings require SupportsPrivateService; custom-port L4 mappings
// require SupportsCustomPorts. Remove operations always pass so proxies can
// clean up.
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
// Old proxies that never reported capabilities are skipped for non-TLS L4
// mappings with a custom listen port, since they don't understand the
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
// are new enough to handle the mapping. TLS uses SNI routing and works on
// any proxy. Delete operations are always sent so proxies can clean up.
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
return true
}
if mapping.GetPrivate() {
caps := conn.capabilities
if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService {
return false
}
}
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
return true
}
@@ -910,29 +691,6 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
}
// filterMappingsForProxy drops mappings the proxy cannot safely receive
// (e.g. private mappings to a proxy without SupportsPrivateService).
// Returns the input unchanged when no filtering is needed.
func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
if update == nil || len(update.Mapping) == 0 {
return update
}
kept := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, m := range update.Mapping {
if !proxyAcceptsMapping(conn, m) {
continue
}
kept = append(kept, m)
}
if len(kept) == len(update.Mapping) {
return update
}
return &proto.GetMappingUpdateResponse{
Mapping: kept,
InitialSyncComplete: update.InitialSyncComplete,
}
}
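
Two properties of `filterMappingsForProxy` are easy to miss: it returns the input pointer untouched when nothing is dropped, and it never mutates the shared update that other proxies will also receive. A minimal sketch with hypothetical `mapping` and `update` types, collapsing the capability check to a single boolean:

```go
package main

import "fmt"

// mapping and update are hypothetical stand-ins for proto.ProxyMapping and
// proto.GetMappingUpdateResponse.
type mapping struct {
	ID      string
	Private bool
	Removed bool
}

type update struct{ Mappings []*mapping }

// accepts reports whether a proxy may receive m. Removals always pass so the
// proxy can clean up; private mappings need the capability.
func accepts(supportsPrivate bool, m *mapping) bool {
	if m.Removed {
		return true
	}
	return !m.Private || supportsPrivate
}

// filterFor returns u itself when nothing is dropped, otherwise a new update
// holding only the accepted mappings. The input is never modified.
func filterFor(supportsPrivate bool, u *update) *update {
	if u == nil || len(u.Mappings) == 0 {
		return u
	}
	kept := make([]*mapping, 0, len(u.Mappings))
	for _, m := range u.Mappings {
		if accepts(supportsPrivate, m) {
			kept = append(kept, m)
		}
	}
	if len(kept) == len(u.Mappings) {
		return u
	}
	return &update{Mappings: kept}
}

func main() {
	u := &update{Mappings: []*mapping{
		{ID: "public"},
		{ID: "private", Private: true},
		{ID: "private-removed", Private: true, Removed: true},
	}}
	fmt.Println(filterFor(true, u) == u)           // nothing dropped: same pointer
	fmt.Println(len(filterFor(false, u).Mappings)) // private create/update dropped
	fmt.Println(len(u.Mappings))                   // original left intact
}
```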
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
@@ -994,10 +752,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
authenticated, userId, method := s.authenticateRequest(ctx, req, service)
// Non-OIDC schemes (PIN/Password/Header) authenticate against per-service
// secrets and have no user-level group context, so groups stay nil. Email
// is also empty — these schemes don't resolve a user record at sign time.
token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil)
token, err := s.generateSessionToken(ctx, authenticated, service, userId, method)
if err != nil {
return nil, err
}
@@ -1086,7 +841,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -1094,11 +849,8 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
token, err := sessionkey.SignToken(
service.SessionPrivateKey,
userId,
userEmail,
service.Domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
if err != nil {
@@ -1109,26 +861,6 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
return token, nil
}
// pairGroupIDsAndNames splits a slice of resolved *types.Group records
// into parallel id and name slices. ids[i] and names[i] always pair to
// the same group. nil entries (orphan ids the manager couldn't resolve)
// are skipped so the consumer can rely on positional pairing.
func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) {
if len(groups) == 0 {
return nil, nil
}
ids = make([]string, 0, len(groups))
names = make([]string, 0, len(groups))
for _, g := range groups {
if g == nil {
continue
}
ids = append(ids, g.ID)
names = append(names, g.Name)
}
return ids, names
}
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
@@ -1393,9 +1125,7 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return verifier, redirectURL, nil
}
// GenerateSessionToken creates a signed session JWT for the given domain and
// user. The user's group memberships are embedded in the token so policy-aware
// middlewares on the proxy can authorise without an extra management round-trip.
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
@@ -1406,29 +1136,11 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
var (
email string
groupIDs []string
groupNames []string
)
if s.usersManager != nil {
user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID)
if uerr != nil {
log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr)
} else if user != nil {
email = user.Email
groupIDs, groupNames = pairGroupIDsAndNames(userGroups)
}
}
return sessionkey.SignToken(
service.SessionPrivateKey,
userID,
email,
domain,
method,
groupIDs,
groupNames,
proxyauth.DefaultSessionExpiry,
)
}
@@ -1532,7 +1244,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1545,7 +1257,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID)
user, err := s.usersManager.GetUser(ctx, userID)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
@@ -1579,15 +1291,12 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"user_id": userID,
"error": err.Error(),
}).Debug("ValidateSession: access denied")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
//nolint:nilerr
return &proto.ValidateSessionResponse{
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
Valid: false,
UserId: user.Id,
UserEmail: user.Email,
DeniedReason: "not_in_group",
}, nil
}
@@ -1597,13 +1306,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
"email": user.Email,
}).Debug("ValidateSession: access granted")
groupIDs, groupNames := pairGroupIDsAndNames(userGroups)
return &proto.ValidateSessionResponse{
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
Valid: true,
UserId: user.Id,
UserEmail: user.Email,
}, nil
}
@@ -1636,154 +1342,3 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *
}
func ptr[T any](v T) *T { return &v }
// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and
// checks the peer's group membership against the service's access groups.
// Peers without a user (machine agents, automation workloads) are first-class
// callers; authorisation runs off peer-group memberships rather than the
// optional owning user's auto-groups. On success a session JWT is minted so
// the proxy can install a cookie and skip subsequent management round-trips.
func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
domain := req.GetDomain()
tunnelIPStr := req.GetTunnelIp()
if domain == "" || tunnelIPStr == "" {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "missing domain or tunnel_ip",
}, nil
}
tunnelIP := net.ParseIP(tunnelIPStr)
if tunnelIP == nil {
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "invalid_tunnel_ip",
}, nil
}
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "service_not_found",
}, nil
}
// Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only
// validate and mint session cookies for their own account's domains.
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP)
if err != nil || peer == nil {
log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
_, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID)
if err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
DeniedReason: "peer_not_found",
}, nil
}
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
// Resolve the principal: when the peer is linked to a user, the human
// is the principal so multiple peers owned by the same user share a
// single identity. Unlinked peers (machine agents) are their own
// principal keyed on peer.ID. displayIdentity is what upstream gateways
// tag spend with — user.Email when linked, peer.Name when not.
principalID := peer.ID
displayIdentity := peer.Name
if peer.UserID != "" {
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
}
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
//nolint:nilerr
return &proto.ValidateTunnelPeerResponse{
Valid: false,
UserId: principalID,
UserEmail: displayIdentity,
DeniedReason: "not_in_group",
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames)
if err != nil {
return nil, err
}
log.WithFields(log.Fields{
"domain": domain,
"tunnel_ip": tunnelIPStr,
"peer_id": peer.ID,
"principal_id": principalID,
}).Debug("ValidateTunnelPeer: access granted")
return &proto.ValidateTunnelPeerResponse{
Valid: true,
UserId: principalID,
UserEmail: displayIdentity,
SessionToken: token,
PeerGroupIds: groupIDs,
PeerGroupNames: groupNames,
}, nil
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security
// boundary and must not trust upstream state). Bearer-auth services authorise
// against DistributionGroups when populated. Non-private non-bearer services
// are open.
func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error {
if service.Private {
if len(service.AccessGroups) == 0 {
return fmt.Errorf("private service has no access groups")
}
return matchAnyGroup(service.AccessGroups, peerGroupIDs)
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 {
return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs)
}
return nil
}
// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups,
// else a non-nil error.
func matchAnyGroup(allowedGroups, peerGroupIDs []string) error {
if len(allowedGroups) == 0 {
return fmt.Errorf("no allowed groups configured")
}
allowed := make(map[string]struct{}, len(allowedGroups))
for _, g := range allowedGroups {
allowed[g] = struct{}{}
}
for _, g := range peerGroupIDs {
if _, ok := allowed[g]; ok {
return nil
}
}
return fmt.Errorf("peer not in allowed groups")
}
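
Reviewer note: the hunks above call `pairGroupIDsAndNames` several times, but its body is not part of this diff. The following is a minimal sketch of what such a helper might look like, assuming it returns two parallel slices where index i of each refers to the same group. The `Group` type and its `ID` and `Name` fields are hypothetical stand-ins, not the project's real type.

```go
package main

import "fmt"

// Group is a hypothetical stand-in for the management server's group type;
// only the two fields this sketch needs are modelled.
type Group struct {
	ID   string
	Name string
}

// pairGroupIDsAndNames is an assumed implementation: it returns two parallel
// slices so that names[i] is the display name of ids[i]. Nil entries are
// skipped so the slices never go out of step.
func pairGroupIDsAndNames(groups []*Group) (ids, names []string) {
	ids = make([]string, 0, len(groups))
	names = make([]string, 0, len(groups))
	for _, g := range groups {
		if g == nil {
			continue
		}
		ids = append(ids, g.ID)
		names = append(names, g.Name)
	}
	return ids, names
}

func main() {
	ids, names := pairGroupIDsAndNames([]*Group{
		{ID: "grp-ops", Name: "Ops"},
		nil,
		{ID: "grp-admins", Name: "Admins"},
	})
	fmt.Println(ids, names)
}
```

Keeping the two slices index-aligned matters because the response messages carry IDs and names as separate repeated fields, so a consumer can only match them up by position.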


@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -129,14 +129,6 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U
return user, nil
}
func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.GetUser(ctx, userID)
if err != nil {
return nil, nil, err
}
return user, nil, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
@@ -428,46 +420,3 @@ func TestGetAccountProxyByDomain(t *testing.T) {
})
}
}
func TestCheckPeerGroupAccess(t *testing.T) {
t.Run("private with empty AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: nil}
err := checkPeerGroupAccess(svc, []string{"grp-admins"})
require.Error(t, err)
assert.Contains(t, err.Error(), "no access groups")
})
t.Run("private with peer in AccessGroups allows", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"}))
})
t.Run("private with peer outside AccessGroups denies", func(t *testing.T) {
svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}}
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"}))
})
t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) {
svc := &service.Service{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"grp-allowed"},
},
},
}
assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"}))
assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"}))
})
t.Run("non-private non-bearer is open", func(t *testing.T) {
assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil))
})
}


@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil
}
log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
@@ -821,80 +821,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}, nil
}
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
// stays up; no network map sync is performed. The new deadline is returned
// in ExtendAuthSessionResponse.SessionExpiresAt.
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
extendReq := &proto.ExtendAuthSessionRequest{}
peerKey, err := s.parseRequest(ctx, req, extendReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
}
jwt := extendReq.GetJwtToken()
if jwt == "" {
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
}
var userID string
const attempts = 3
for i := 0; i < attempts; i++ {
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
if err == nil {
break
}
if i == attempts-1 {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
select {
case <-time.After(200 * time.Millisecond):
case <-ctx.Done():
return nil, ctx.Err()
}
}
if err != nil {
return nil, err
}
if userID == "" {
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
}
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
if err != nil {
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
// Success path normally returns a non-zero deadline. A defensive zero
// would still encode as the explicit "disabled" sentinel rather than nil,
// so the client clears any stale anchor instead of preserving it.
resp := &proto.ExtendAuthSessionResponse{
SessionExpiresAt: encodeSessionExpiresAt(deadline),
}
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed processing request")
}
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed encrypting response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encrypted,
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
@@ -918,12 +844,6 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
Checks: toProtocolChecks(ctx, postureChecks),
}
// settings is always non-nil here, so we never emit nil — encoder returns
// either a valid deadline or the explicit-zero "disabled" sentinel.
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
return loginResp, nil
}


@@ -1,411 +0,0 @@
package grpc
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
// sent messages and returns pre-loaded ack responses from Recv.
type syncRecordingStream struct {
grpc.ServerStream
mu sync.Mutex
sent []*proto.SyncMappingsResponse
recvMsgs []*proto.SyncMappingsRequest
recvIdx int
}
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sent = append(s.sent, m)
return nil
}
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.recvIdx >= len(s.recvMsgs) {
return nil, fmt.Errorf("no more recv messages")
}
msg := s.recvMsgs[s.recvIdx]
s.recvIdx++
return msg, nil
}
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
func (s *syncRecordingStream) SendMsg(any) error { return nil }
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
func ackMsg() *proto.SyncMappingsRequest {
return &proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
}
}
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 3
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
assert.Len(t, stream.sent[0].Mapping, 3)
assert.False(t, stream.sent[0].InitialSyncComplete)
assert.Len(t, stream.sent[1].Mapping, 3)
assert.False(t, stream.sent[1].InitialSyncComplete)
assert.Len(t, stream.sent[2].Mapping, 1)
assert.True(t, stream.sent[2].InitialSyncComplete)
// All 3 acks consumed — including the final batch.
assert.Equal(t, 3, stream.recvIdx)
}
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 100
const totalServices = 5
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 1)
assert.Len(t, stream.sent[0].Mapping, totalServices)
assert.True(t, stream.sent[0].InitialSyncComplete)
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
}
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
const cluster = "cluster.example.com"
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
s := newSnapshotTestServer(t, 500)
s.serviceManager = mgr
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
assert.Empty(t, stream.sent[0].Mapping)
assert.True(t, stream.sent[0].InitialSyncComplete)
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
}
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
// No acks available — Recv will return error.
stream := &syncRecordingStream{}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.Error(t, err)
assert.Contains(t, err.Error(), "receive ack")
// First batch should have been sent before the error.
require.Len(t, stream.sent, 1)
}
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
// Send an init message instead of an ack.
stream := &syncRecordingStream{
recvMsgs: []*proto.SyncMappingsRequest{
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
},
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.Error(t, err)
assert.Contains(t, err.Error(), "expected ack")
}
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
// Verify batches are sent strictly sequentially — batch N+1 is not sent
// until the ack for batch N is received, including the final batch.
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 6 // 3 batches, 3 acks
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
var mu sync.Mutex
var events []string
// Build a stream that logs send/recv events so we can verify ordering.
ackCh := make(chan struct{}, 3)
stream := &orderTrackingStream{
mu: &mu,
events: &events,
ackCh: ackCh,
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
// Feed acks asynchronously after a short delay to simulate real proxy.
go func() {
for range 3 {
time.Sleep(10 * time.Millisecond)
ackCh <- struct{}{}
}
}()
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
mu.Lock()
defer mu.Unlock()
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
require.Len(t, events, 6)
assert.Equal(t, "send", events[0])
assert.Equal(t, "recv", events[1])
assert.Equal(t, "send", events[2])
assert.Equal(t, "recv", events[3])
assert.Equal(t, "send", events[4])
assert.Equal(t, "recv", events[5])
}
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
// an ack is signaled via ackCh.
type orderTrackingStream struct {
grpc.ServerStream
mu *sync.Mutex
events *[]string
ackCh chan struct{}
}
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
s.mu.Lock()
*s.events = append(*s.events, "send")
s.mu.Unlock()
return nil
}
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
<-s.ackCh
s.mu.Lock()
*s.events = append(*s.events, "recv")
s.mu.Unlock()
return ackMsg(), nil
}
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
func (s *orderTrackingStream) SendMsg(any) error { return nil }
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 4
const ttl = 100 * time.Millisecond
const ackDelay = 200 * time.Millisecond
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
s.tokenTTL = ttl
// Build a stream that validates tokens immediately on Send, then
// delays the ack to ensure the next batch's tokens are generated fresh.
var validateErrs []error
ackCh := make(chan struct{}, 2)
stream := &tokenValidatingSyncStream{
tokenStore: s.tokenStore,
validateErrs: &validateErrs,
ackCh: ackCh,
}
conn := &proxyConnection{
proxyID: "proxy-a",
address: cluster,
syncStream: stream,
}
go func() {
// Delay first ack so that if tokens were all generated upfront they'd expire.
time.Sleep(ackDelay)
ackCh <- struct{}{}
// Final batch ack — immediate.
ackCh <- struct{}{}
}()
err := s.sendSnapshotSync(context.Background(), conn, stream)
require.NoError(t, err)
require.Empty(t, validateErrs,
"tokens must remain valid: per-batch generation guarantees freshness")
}
type tokenValidatingSyncStream struct {
grpc.ServerStream
tokenStore *OneTimeTokenStore
validateErrs *[]error
ackCh chan struct{}
}
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
for _, mapping := range m.Mapping {
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
}
}
return nil
}
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
<-s.ackCh
return ackMsg(), nil
}
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
stream := &syncRecordingStream{}
conn := &proxyConnection{
syncStream: stream,
}
resp := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
},
InitialSyncComplete: true,
}
err := conn.sendResponse(resp)
require.NoError(t, err)
require.Len(t, stream.sent, 1)
assert.Len(t, stream.sent[0].Mapping, 1)
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
assert.True(t, stream.sent[0].InitialSyncComplete)
}
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
stream := &recordingStream{}
conn := &proxyConnection{
stream: stream,
}
resp := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
{Id: "svc-2", AccountId: "acct-2"},
},
}
err := conn.sendResponse(resp)
require.NoError(t, err)
require.Len(t, stream.messages, 1)
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
}
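
Reviewer note: the deleted tests above pin down the contract of `sendSnapshotSync` (one ack per batch including the final one, a sync-complete message even for an empty snapshot, and an error when an ack is missing), but the function is not in this diff. A minimal sketch of a send-then-wait loop that satisfies the same contract follows. The `batch`, `stream`, and `recorder` types and the `sendInBatches` name are hypothetical and stand in for the real gRPC stream and proto messages.

```go
package main

import (
	"errors"
	"fmt"
)

// batch is a hypothetical stand-in for one snapshot response message.
type batch struct {
	items    []string
	complete bool // true only on the final batch
}

// stream is a hypothetical minimal view of a bidirectional stream.
type stream interface {
	Send(batch) error
	RecvAck() error
}

// sendInBatches sends items in chunks of batchSize and blocks on an ack
// after every chunk, including the last. An empty input still produces one
// empty batch flagged complete.
func sendInBatches(s stream, items []string, batchSize int) error {
	if batchSize <= 0 {
		return errors.New("batch size must be positive")
	}
	for start := 0; ; start += batchSize {
		end := start + batchSize
		if end > len(items) {
			end = len(items)
		}
		last := end == len(items)
		if err := s.Send(batch{items: items[start:end], complete: last}); err != nil {
			return fmt.Errorf("send batch: %w", err)
		}
		// Back-pressure: the next batch is not built until this one is acked.
		if err := s.RecvAck(); err != nil {
			return fmt.Errorf("receive ack: %w", err)
		}
		if last {
			return nil
		}
	}
}

// recorder is a test double that records sent batches and hands out a
// fixed number of acks.
type recorder struct {
	sent []batch
	acks int
}

func (r *recorder) Send(b batch) error { r.sent = append(r.sent, b); return nil }

func (r *recorder) RecvAck() error {
	if r.acks == 0 {
		return errors.New("no more acks")
	}
	r.acks--
	return nil
}

func main() {
	r := &recorder{acks: 3}
	err := sendInBatches(r, []string{"a", "b", "c", "d", "e", "f", "g"}, 3)
	fmt.Println(len(r.sent), r.sent[len(r.sent)-1].complete, err)
}
```

Waiting for the ack before building the next chunk is also what makes per-batch token generation possible: anything with a short TTL can be created immediately before the send rather than up front for the whole snapshot.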


@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
@@ -125,7 +125,6 @@ func TestValidateSession_UserAllowed(t *testing.T) {
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships")
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
@@ -146,7 +145,6 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial")
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
@@ -324,7 +322,7 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}


@@ -282,7 +282,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account.
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -355,17 +355,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
// Session deadline is derived from LastLogin + PeerLoginExpiration
// on every Login/Sync response. Without a fan-out push, connected
// peers keep the deadline they received at login time and only see
// the new value after the next unrelated NetworkMap change. Add
// these two fields to the trigger list so admin-side expiry tweaks
// (e.g. shortening from 24h to 1h) reach every connected peer
// within seconds, which is what the proactive-warning feature
// relies on (see client/internal/auth/sessionwatch).
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
updateAccountPeers = true
}
@@ -855,7 +845,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -1422,7 +1412,7 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1435,7 +1425,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
// GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1448,7 +1438,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1473,7 +1463,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -1540,8 +1530,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
ctx, err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false)
if err != nil {
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return "", "", err
}
@@ -1987,7 +1976,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -2555,7 +2544,7 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
@@ -2645,7 +2634,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
// UpdatePeerIPv6 updates the IPv6 overlay address of a peer, validating it's
// within the account's v6 network range and not already taken.
func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error {
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}

Some files were not shown because too many files have changed in this diff.