mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-23 00:59:54 +00:00
Compare commits
10 Commits
nmap/compo
...
sha-pinnin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b59191665 | ||
|
|
0358be2313 | ||
|
|
37052fd5bc | ||
|
|
454ff66518 | ||
|
|
6137a1fcc5 | ||
|
|
4955c345d5 | ||
|
|
9192b4f029 | ||
|
|
c784b02550 | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 |
46
.github/actions/parse-semver/action.yml
vendored
Normal file
46
.github/actions/parse-semver/action.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: Parse semver string
|
||||
description: Parse a refs/tags/vX.Y.Z[-prerelease][+build] ref into version components. Falls back to 0.0.0 when not run on a version tag.
|
||||
outputs:
|
||||
major:
|
||||
description: Major version number (e.g. "1" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.major }}
|
||||
minor:
|
||||
description: Minor version number (e.g. "2" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.minor }}
|
||||
patch:
|
||||
description: Patch version number (e.g. "3" from v1.2.3).
|
||||
value: ${{ steps.parse.outputs.patch }}
|
||||
prerelease:
|
||||
description: Prerelease identifier (e.g. "rc.1" from v1.2.3-rc.1), empty when absent.
|
||||
value: ${{ steps.parse.outputs.prerelease }}
|
||||
build:
|
||||
description: Build metadata (e.g. "build.7" from v1.2.3+build.7), empty when absent.
|
||||
value: ${{ steps.parse.outputs.build }}
|
||||
fullversion:
|
||||
description: MAJOR.MINOR.PATCH joined (e.g. "1.2.3").
|
||||
value: ${{ steps.parse.outputs.fullversion }}
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- id: parse
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
REF="${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}"
|
||||
if [[ "$REF" =~ ^refs/tags/v([0-9]+)\.([0-9]+)\.([0-9]+)(-([0-9A-Za-z.-]+))?(\+([0-9A-Za-z.-]+))?$ ]]; then
|
||||
MAJOR="${BASH_REMATCH[1]}"
|
||||
MINOR="${BASH_REMATCH[2]}"
|
||||
PATCH="${BASH_REMATCH[3]}"
|
||||
PRERELEASE="${BASH_REMATCH[5]}"
|
||||
BUILD="${BASH_REMATCH[7]}"
|
||||
else
|
||||
MAJOR=0; MINOR=0; PATCH=0; PRERELEASE=""; BUILD=""
|
||||
fi
|
||||
{
|
||||
echo "major=$MAJOR"
|
||||
echo "minor=$MINOR"
|
||||
echo "patch=$PATCH"
|
||||
echo "prerelease=$PRERELEASE"
|
||||
echo "build=$BUILD"
|
||||
echo "fullversion=$MAJOR.$MINOR.$PATCH"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
34
.github/dependabot.yml
vendored
Normal file
34
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
105
.github/workflows/check-license-dependencies.yml
vendored
105
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +57,55 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
6
.github/workflows/git-town.yml
vendored
6
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: git-town/action@670e1f4feb81fdef4226fc09deefe09018eb20d1 # v1.3.3
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,16 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +44,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
19
.github/workflows/golang-test-freebsd.yml
vendored
19
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,29 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
118
.github/workflows/golang-test-linux.yml
vendored
118
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +36,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +113,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +128,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -158,14 +158,14 @@ jobs:
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +177,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -231,10 +231,10 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +246,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +277,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +298,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +324,14 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +343,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +370,19 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +390,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +410,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +427,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +437,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +474,10 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +485,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +505,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +529,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +566,10 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +577,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +597,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +623,20 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +644,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
26
.github/workflows/golang-test-windows.yml
vendored
26
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,10 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +44,22 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
10
.github/workflows/golangci-lint.yml
vendored
10
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +38,13 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
2
.github/workflows/install-script-test.yml
vendored
2
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
14
.github/workflows/mobile-build-validation.yml
vendored
14
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,23 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +52,9 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
2
.github/workflows/proto-version-check.yml
vendored
2
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
|
||||
181
.github/workflows/release.yml
vendored
181
.github/workflows/release.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +51,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +100,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +131,24 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +161,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +196,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +287,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +319,25 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +378,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +407,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +421,16 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +444,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +452,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +477,24 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: ./.github/actions/parse-semver
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +514,39 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://pkgs.netbird.io/wintun/wintun-0.14.1.zip"
|
||||
$dest = "${env:downloadPath}\wintun.zip"
|
||||
$expected = "07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51"
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "wintun.zip checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
shell: pwsh
|
||||
run: |
|
||||
$url = "https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z"
|
||||
$dest = "${env:downloadPath}\mesa3d.7z"
|
||||
$expected = "71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9"
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 -o $dest $url
|
||||
$actual = (Get-FileHash $dest -Algorithm SHA256).Hash.ToLower()
|
||||
if ($actual -ne $expected) {
|
||||
throw "mesa3d.7z checksum mismatch: expected $expected, got $actual"
|
||||
}
|
||||
"file-path=$dest" | Out-File -FilePath $env:GITHUB_OUTPUT -Append -Encoding utf8
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +557,38 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 `
|
||||
-o "${{ github.workspace }}\envar_plugin.zip" `
|
||||
"https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip"
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
shell: pwsh
|
||||
run: |
|
||||
curl.exe --fail --location --retry 3 --retry-delay 2 `
|
||||
-o "${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z" `
|
||||
"https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +605,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +624,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +716,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
22
.github/workflows/test-infrastructure-files.yml
vendored
22
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,15 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +139,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +254,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
11
.github/workflows/wasm-build-validation.yml
vendored
11
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,15 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +42,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +65,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
153
README.md
153
README.md
@@ -1,147 +1,134 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
### Self-host NetBird (video)
|
||||
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
m.done = make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
|
||||
go m.run(ctx)
|
||||
go m.run(ctx, done)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
@@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() {
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
||||
defer close(m.done)
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
|
||||
@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
|
||||
@@ -61,11 +61,9 @@ import (
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
@@ -204,13 +202,6 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||
// NetworkMap that Calculate() produced from it so Step 3 incremental
|
||||
// updates have a base to apply changes against. nil for legacy-format
|
||||
// peers. Guarded by syncMsgMux.
|
||||
latestComponents *types.NetworkMapComponents
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -874,12 +865,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
@@ -920,45 +907,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
nm *mgmProto.NetworkMap
|
||||
components *types.NetworkMapComponents
|
||||
)
|
||||
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||
// Components-format peer: decode the envelope back to typed
|
||||
// components, run Calculate() locally, and convert to the wire
|
||||
// NetworkMap shape the rest of the engine consumes. Components are
|
||||
// retained so future incremental updates (Step 3) can apply deltas
|
||||
// instead of doing a full reconstruction.
|
||||
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
dnsName := ""
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||
// shared domain by stripping the peer's own label prefix. Falls
|
||||
// back to empty if the FQDN doesn't have the expected shape.
|
||||
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||
}
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode network map envelope: %w", err)
|
||||
}
|
||||
nm = result.NetworkMap
|
||||
components = result.Components
|
||||
} else {
|
||||
nm = update.GetNetworkMap()
|
||||
}
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only retain the components view when the server sent the envelope
|
||||
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||
// here would clobber a previously-cached snapshot, breaking the Step 3
|
||||
// incremental-delta base on a future envelope sync.
|
||||
if components != nil {
|
||||
e.latestComponents = components
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
e.syncRespMux.RLock()
|
||||
@@ -984,19 +937,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||
// receiving peer's FQDN — the same value the management server fills as
|
||||
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||
for i := 0; i < len(fqdn); i++ {
|
||||
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||
return fqdn[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
|
||||
@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||
// enough for teardown to finish while staying under that deadline.
|
||||
timeout := time.NewTimer(20 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -64,13 +64,6 @@
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
|
||||
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
|
||||
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKCU"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<StandardDirectory Id="CommonAppDataFolder">
|
||||
@@ -83,28 +76,10 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
|
||||
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
|
||||
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
|
||||
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
|
||||
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
|
||||
</Component>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
<ComponentRef Id="NetbirdAutoStart" />
|
||||
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
|
||||
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
|
||||
@@ -53,9 +53,6 @@ type NameServerGroup struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||
// Name group name
|
||||
Name string
|
||||
// Description group description
|
||||
|
||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
if file == "" {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return newSQLite3(file).Open(logger)
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/server/signer"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -76,7 +77,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Initialize SQLite storage
|
||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||
sqliteConfig := newSQLite3(dbPath)
|
||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
||||
stor, err := sqliteConfig.Open(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||
|
||||
@@ -55,15 +55,6 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
// componentsDisabled is the kill switch for the component-based wire
|
||||
// format. When true the controller emits legacy proto.NetworkMap to every
|
||||
// peer regardless of capability — used to roll back instantly via a
|
||||
// management restart from a bad components encoder.
|
||||
//
|
||||
// Set once in NewController from NB_NETWORK_MAP_COMPONENTS_DISABLE and
|
||||
// never written after — readers race-free without a mutex.
|
||||
componentsDisabled bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -90,30 +81,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
}
|
||||
}
|
||||
|
||||
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||
// component-based wire format for this peer. Combines the peer's advertised
|
||||
// capability with the controller-level kill switch — callers ask exactly
|
||||
// this question, so encapsulating it removes accidental double-checks.
|
||||
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||
}
|
||||
|
||||
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||
// literal — matches the convention used elsewhere in the codebase
|
||||
// (e.g. event.go's NB_TRAFFIC_EVENT_*) and reduces operator surprises.
|
||||
func parseBoolEnv(key string) bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -219,26 +192,18 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||
// the client merges it into Calculate()'s output the same
|
||||
// way the legacy server did via NetworkMap.Merge.
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
@@ -349,11 +314,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
@@ -364,12 +329,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
@@ -416,67 +376,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||
// encodes both into the wire envelope. The caller is responsible for
|
||||
// checking peer capability + componentsDisabled before dispatching here —
|
||||
// this method does NOT branch on capability itself.
|
||||
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
// Fetch the proxy network map fragment for this peer alongside the
|
||||
// components — same single-account-load path the streaming controller
|
||||
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||
// instead of waiting for the next streaming push.
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
|
||||
@@ -22,10 +22,6 @@ type Controller interface {
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
@@ -130,39 +130,6 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||
ret2, _ := ret[2].(*types.NetworkMap)
|
||||
ret3, _ := ret[3].([]*posture.Checks)
|
||||
ret4, _ := ret[4].(int64)
|
||||
ret5, _ := ret[5].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// PeerNeedsComponents mocks base method.
|
||||
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -17,7 +17,7 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
|
||||
@@ -42,10 +42,35 @@ func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// ClusterType is the source of a proxy cluster.
|
||||
type ClusterType string
|
||||
|
||||
const (
|
||||
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
|
||||
// at least one proxy row in the cluster carries a non-NULL account_id.
|
||||
ClusterTypeAccount ClusterType = "account"
|
||||
// ClusterTypeShared is a cluster operated by NetBird and shared across
|
||||
// accounts — all proxy rows in the cluster have account_id IS NULL.
|
||||
ClusterTypeShared ClusterType = "shared"
|
||||
)
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
//
|
||||
// Online and ConnectedProxies derive from the same 2-min active window
|
||||
// the rest of the module uses, but Cluster rows are not gated on it —
|
||||
// the cluster listing surfaces offline clusters too so operators can
|
||||
// see and clean them up. The 1-hour heartbeat reaper still bounds the
|
||||
// table eventually.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
|
||||
@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetAllServices mocks base method.
|
||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
// GetClusters mocks base method.
|
||||
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
// GetClusters indicates an expected call of GetClusters.
|
||||
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
@@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -196,10 +196,14 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
Type: api.ProxyClusterType(c.Type),
|
||||
Online: c.Online,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ type ClusterDeriver interface {
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -112,8 +113,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
// GetClusters returns every proxy cluster visible to the account
|
||||
// (shared + its own BYOP), regardless of whether any proxy in the
|
||||
// cluster is currently heartbeating. Each cluster is enriched with the
|
||||
// capability flags reported by its active proxies so the dashboard can
|
||||
// render feature support without a second round-trip.
|
||||
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -122,7 +127,18 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
clusters, err := m.store.GetProxyClusters(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range clusters {
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
|
||||
@@ -1,815 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||
const wgKeyRawLen = 32
|
||||
|
||||
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||
// In Step 2 the envelope is fully self-contained — every field needed by the
|
||||
// client's local Calculate() comes from the components struct itself. The
|
||||
// only externally-supplied data is the receiving peer's PeerConfig (which is
|
||||
// computed alongside the components in the network_map controller and reused
|
||||
// from the legacy proto path) and the dns_domain string.
|
||||
type ComponentsEnvelopeInput struct {
|
||||
Components *types.NetworkMapComponents
|
||||
PeerConfig *proto.PeerConfig
|
||||
DNSDomain string
|
||||
DNSForwarderPort int64
|
||||
// UserIDClaim is the OIDC claim name the client should embed in
|
||||
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||
// is OK — client treats empty as "no SshAuth to build".
|
||||
UserIDClaim string
|
||||
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||
// is present; encoder skips the field in that case.
|
||||
ProxyPatch *proto.ProxyPatch
|
||||
}
|
||||
|
||||
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||
// Go maps in their native (random) order. Indexes inside the envelope
|
||||
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||
// are self-consistent within a single encode, so the decoder reconstructs
|
||||
// the same typed objects regardless of emit order. Tests that need to
|
||||
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||
// not byte-equal.
|
||||
//
|
||||
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||
// index spaces are local to a single envelope. Delta sync (Step 3+) will
|
||||
// use a different shape for the same reason.
|
||||
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||
c := in.Components
|
||||
|
||||
// Graceful degrade when components is nil — matches the legacy path's
|
||||
// account_components.go:43 behaviour for missing/unvalidated peers
|
||||
// (return a NetworkMap with only Network populated). The receiver gets
|
||||
// an envelope it can decode without crashing; AccountSettings stays
|
||||
// non-nil so client-side dereferences are safe.
|
||||
if c == nil {
|
||||
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||
// populated (account_components.go:43). The receiver gets enough to
|
||||
// bootstrap (Network identifier, dns_domain, account_settings) and
|
||||
// nothing else.
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{
|
||||
Full: &proto.NetworkMapComponentsFull{
|
||||
PeerConfig: in.PeerConfig,
|
||||
DnsDomain: in.DNSDomain,
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
AccountSettings: &proto.AccountSettingsCompact{},
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||
// peers that exist only in c.RouterPeers would silently lose their
|
||||
// peer_index reference.
|
||||
enc := newComponentEncoder(c)
|
||||
enc.indexAllPeers()
|
||||
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||
|
||||
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||
// translate every *Policy pointer to a wire index.
|
||||
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||
|
||||
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||
// every encoder either reads from the dedup tables or works on
|
||||
// independent input.
|
||||
full := &proto.NetworkMapComponentsFull{
|
||||
Serial: networkSerial(c.Network),
|
||||
PeerConfig: in.PeerConfig,
|
||||
Network: toAccountNetwork(c.Network),
|
||||
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||
DnsDomain: in.DNSDomain,
|
||||
CustomZoneDomain: c.CustomZoneDomain,
|
||||
AgentVersions: enc.agentVersions,
|
||||
Peers: enc.peers,
|
||||
RouterPeerIndexes: routerIdxs,
|
||||
Policies: policies,
|
||||
Groups: enc.encodeGroups(),
|
||||
Routes: enc.encodeRoutes(c.Routes),
|
||||
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||
AccountZones: encodeCustomZones(c.AccountZones),
|
||||
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||
}
|
||||
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||
}
|
||||
}
|
||||
|
||||
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||
// production path always populates c.Network (account_components.go:86), but
|
||||
// the encoder is exported and a hand-built components struct may omit it.
|
||||
func networkSerial(n *types.Network) uint64 {
|
||||
if n == nil {
|
||||
return 0
|
||||
}
|
||||
return n.CurrentSerial()
|
||||
}
|
||||
|
||||
type componentEncoder struct {
|
||||
components *types.NetworkMapComponents
|
||||
|
||||
peerOrder map[string]uint32
|
||||
peers []*proto.PeerCompact
|
||||
|
||||
agentVersionOrder map[string]uint32
|
||||
agentVersions []string
|
||||
}
|
||||
|
||||
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||
return &componentEncoder{
|
||||
components: c,
|
||||
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||
agentVersionOrder: make(map[string]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) indexAllPeers() {
|
||||
for _, p := range e.components.Peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
e.appendPeer(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := uint32(len(e.peers))
|
||||
e.peerOrder[p.ID] = idx
|
||||
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||
return idx
|
||||
}
|
||||
|
||||
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||
return idx
|
||||
}
|
||||
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||
// a WtVersion don't force the table to materialise.
|
||||
if v == "" {
|
||||
idx := uint32(len(e.agentVersions))
|
||||
if idx == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
}
|
||||
e.agentVersionOrder[""] = idx
|
||||
return idx
|
||||
}
|
||||
if len(e.agentVersions) == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
e.agentVersionOrder[""] = 0
|
||||
}
|
||||
idx := uint32(len(e.agentVersions))
|
||||
e.agentVersionOrder[v] = idx
|
||||
e.agentVersions = append(e.agentVersions, v)
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||
if len(routers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(routers))
|
||||
for _, p := range routers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, e.appendPeer(p))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||
if len(e.components.Groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||
for _, g := range e.components.Groups {
|
||||
if !g.HasSeqID() {
|
||||
continue
|
||||
}
|
||||
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||
for _, peerID := range g.Peers {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
peerIdxs = append(peerIdxs, idx)
|
||||
}
|
||||
}
|
||||
out = append(out, &proto.GroupCompact{
|
||||
Id: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
PeerIndexes: peerIdxs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||
// that list — used by encodeResourcePoliciesMap to translate
|
||||
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||
|
||||
for _, pol := range policies {
|
||||
if !pol.HasSeqID() || !pol.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, r := range pol.Rules {
|
||||
if r == nil || !r.Enabled {
|
||||
continue
|
||||
}
|
||||
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||
out = append(out, e.encodePolicyRule(pol, r))
|
||||
}
|
||||
}
|
||||
return out, idxByPolicy
|
||||
}
|
||||
|
||||
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||
return &proto.PolicyCompact{
|
||||
Id: pol.AccountSeqID,
|
||||
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||
Bidirectional: r.Bidirectional,
|
||||
Ports: portsToUint32(r.Ports),
|
||||
PortRanges: portRangesToProto(r.PortRanges),
|
||||
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||
AuthorizedUser: r.AuthorizedUser,
|
||||
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||
SourceResource: e.resourceToProto(r.SourceResource),
|
||||
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||
}
|
||||
}
|
||||
|
||||
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||
// dropping any group that has no seq id assigned.
|
||||
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(src))
|
||||
for _, gid := range src {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// unionPolicies merges c.Policies with every policy referenced by
|
||||
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||
// from the wire and the client's resource-policy lookup would come back
|
||||
// empty.
|
||||
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||
// Fast path: non-router peers have no resource-only policies, so the
|
||||
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||
if len(resourcePolicies) == 0 {
|
||||
return policies
|
||||
}
|
||||
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
for _, list := range resourcePolicies {
|
||||
for _, p := range list {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||
// group xid → local-user names) to the wire form (map keyed by group
|
||||
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||
// matches how source/destination group references handle the same case.
|
||||
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||
for groupID, names := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||
g, ok := e.components.Groups[groupID]
|
||||
if !ok || !g.HasSeqID() {
|
||||
return 0, false
|
||||
}
|
||||
return g.AccountSeqID, true
|
||||
}
|
||||
|
||||
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||
// resources the peer id is converted to a peer index into the envelope's
|
||||
// peers array. For other resource types only the type string is shipped
|
||||
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||
// for "peer" — other types fall through to group-based lookup).
|
||||
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||
if r.ID == "" && r.Type == "" {
|
||||
return nil
|
||||
}
|
||||
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||
out.PeerIndexSet = true
|
||||
out.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||
// references handle the same case.
|
||||
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(xids))
|
||||
for _, xid := range xids {
|
||||
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// networkSeq translates a Network xid to its per-account integer id using
|
||||
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||
// the xid isn't known — callers decide whether to skip the parent record.
|
||||
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||
if xid == "" {
|
||||
return 0, false
|
||||
}
|
||||
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||
if !ok || seq == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return seq, true
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := &proto.DNSSettingsCompact{
|
||||
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||
}
|
||||
for _, gid := range s.DisabledManagementGroups {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||
if len(routes) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
rr := &proto.RouteRaw{
|
||||
Id: r.AccountSeqID,
|
||||
NetId: string(r.NetID),
|
||||
Description: r.Description,
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: int32(r.NetworkType),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
SkipAutoApply: r.SkipAutoApply,
|
||||
Domains: r.Domains.ToPunycodeList(),
|
||||
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
}
|
||||
if r.Network.IsValid() {
|
||||
rr.NetworkCidr = r.Network.String()
|
||||
}
|
||||
if r.Peer != "" {
|
||||
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||
rr.PeerIndexSet = true
|
||||
rr.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||
if len(nsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||
for _, nsg := range nsgs {
|
||||
if nsg == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NameServerGroupRaw{
|
||||
Id: nsg.AccountSeqID,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Nameservers: encodeNameServers(nsg.NameServers),
|
||||
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServer, 0, len(servers))
|
||||
for _, s := range servers {
|
||||
out = append(out, &proto.NameServer{
|
||||
IP: s.IP.String(),
|
||||
NSType: int64(s.NSType),
|
||||
Port: int64(s.Port),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, &proto.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int64(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int64(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||
if len(zones) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, &proto.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: encodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||
if len(resources) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||
for _, r := range resources {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkResourceRaw{
|
||||
Id: r.AccountSeqID,
|
||||
Name: r.Name,
|
||||
Description: r.Description,
|
||||
Type: string(r.Type),
|
||||
Address: r.Address,
|
||||
DomainValue: r.Domain,
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||
entry.NetworkSeq = seq
|
||||
}
|
||||
if r.Prefix.IsValid() {
|
||||
entry.PrefixCidr = r.Prefix.String()
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||
if len(routersMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||
for networkXID, routers := range routersMap {
|
||||
if len(routers) == 0 {
|
||||
continue
|
||||
}
|
||||
netSeq, ok := e.networkSeq(networkXID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||
for peerID, r := range routers {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkRouterEntry{
|
||||
Id: r.AccountSeqID,
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
entry.PeerIndexSet = true
|
||||
entry.PeerIndex = idx
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||
if len(rpm) == 0 {
|
||||
return nil
|
||||
}
|
||||
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||
// (small slice). Network resources without seq id are dropped, matching how
|
||||
// other components-without-seq are silently filtered.
|
||||
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||
for _, r := range e.components.NetworkResources {
|
||||
if r != nil && r.AccountSeqID != 0 {
|
||||
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||
}
|
||||
}
|
||||
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||
for resourceXID, policies := range rpm {
|
||||
seq, ok := resourceXIDToSeq[resourceXID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(policies)*2)
|
||||
for _, pol := range policies {
|
||||
idxs = append(idxs, policyToIdxs[pol]...)
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||
for groupID, userIDs := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok || len(userIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSetToSlice(s map[string]struct{}) []string {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s))
|
||||
for k := range s {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||
for checkXID, failedPeerIDs := range m {
|
||||
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||
if !ok || seq == 0 {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||
for peerID := range failedPeerIDs {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||
// (which shouldn't happen in production but the encoder is exported)
|
||||
// degrades to login_expiration_enabled = false, which makes
|
||||
// LoginExpired() return false for every peer.
|
||||
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||
if s == nil {
|
||||
return &proto.AccountSettingsCompact{}
|
||||
}
|
||||
return &proto.AccountSettingsCompact{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||
}
|
||||
}
|
||||
|
||||
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
out := &proto.AccountNetwork{
|
||||
Identifier: n.Identifier,
|
||||
NetCidr: n.Net.String(),
|
||||
Dns: n.Dns,
|
||||
Serial: n.CurrentSerial(),
|
||||
}
|
||||
if len(n.NetV6.IP) > 0 {
|
||||
out.NetV6Cidr = n.NetV6.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||
pc := &proto.PeerCompact{
|
||||
WgPubKey: decodeWgKey(p.Key),
|
||||
SshPubKey: []byte(p.SSHKey),
|
||||
DnsLabel: p.DNSLabel,
|
||||
AgentVersionIdx: agentVersionIdx,
|
||||
AddedWithSsoLogin: p.UserID != "",
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
SshEnabled: p.SSHEnabled,
|
||||
SupportsIpv6: p.SupportsIPv6(),
|
||||
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||
}
|
||||
if p.LastLogin != nil {
|
||||
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||
}
|
||||
switch {
|
||||
case !p.IP.IsValid():
|
||||
// leave Ip nil
|
||||
case p.IP.Is4() || p.IP.Is4In6():
|
||||
ip := p.IP.Unmap().As4()
|
||||
pc.Ip = ip[:]
|
||||
default:
|
||||
ip := p.IP.As16()
|
||||
pc.Ip = ip[:]
|
||||
}
|
||||
if p.IPv6.IsValid() {
|
||||
ip := p.IPv6.As16()
|
||||
pc.Ipv6 = ip[:]
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||
// key, or nil for an empty / malformed key.
|
||||
func decodeWgKey(s string) []byte {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, wgKeyRawLen)
|
||||
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||
if err != nil || n != wgKeyRawLen {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portsToUint32(ports []string) []uint32 {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(ports))
|
||||
for _, p := range ports {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, uint32(v))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
out = append(out, &proto.PortInfo_Range{
|
||||
Start: uint32(r.Start),
|
||||
End: uint32(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,879 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
|
||||
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||
// input compare byte-equal via proto.Equal.
|
||||
//
|
||||
// This lives on the test side — the encoder itself emits in map-iteration
|
||||
// order. Test-side normalization is the contract for "two encodes are
|
||||
// equivalent".
|
||||
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||
if full == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||
// index 0 by convention.
|
||||
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||
if len(full.AgentVersions) > 0 {
|
||||
// Pair version → original index, sort, rebuild.
|
||||
type avEntry struct {
|
||||
version string
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]avEntry, len(full.AgentVersions))
|
||||
for i, v := range full.AgentVersions {
|
||||
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||
}
|
||||
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||
// keeps the canonicalize output stable when two entries compare
|
||||
// equal (the encoder dedups, but defending against future inputs).
|
||||
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||
if a.version == "" && b.version != "" {
|
||||
return -1
|
||||
}
|
||||
if b.version == "" && a.version != "" {
|
||||
return 1
|
||||
}
|
||||
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||
})
|
||||
newVersions := make([]string, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
avRemap[e.oldIdx] = uint32(newIdx)
|
||||
newVersions[newIdx] = e.version
|
||||
}
|
||||
full.AgentVersions = newVersions
|
||||
}
|
||||
for _, p := range full.Peers {
|
||||
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||
p.AgentVersionIdx = newIdx
|
||||
}
|
||||
}
|
||||
|
||||
type peerEntry struct {
|
||||
peer *proto.PeerCompact
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]peerEntry, len(full.Peers))
|
||||
for i, p := range full.Peers {
|
||||
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||
}
|
||||
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||
// nil from malformed keys, or both empty for placeholders).
|
||||
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||
})
|
||||
|
||||
remap := make(map[uint32]uint32, len(entries))
|
||||
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
remap[e.oldIdx] = uint32(newIdx)
|
||||
newPeers[newIdx] = e.peer
|
||||
}
|
||||
full.Peers = newPeers
|
||||
|
||||
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||
for _, g := range full.Groups {
|
||||
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||
}
|
||||
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, r := range full.Routes {
|
||||
if r.PeerIndexSet {
|
||||
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||
r.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(r.GroupIds)
|
||||
slices.Sort(r.AccessControlGroupIds)
|
||||
slices.Sort(r.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, list := range full.RoutersMap {
|
||||
for _, entry := range list.Entries {
|
||||
if entry.PeerIndexSet {
|
||||
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||
entry.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(entry.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||
}
|
||||
|
||||
for _, set := range full.PostureFailedPeers {
|
||||
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||
}
|
||||
|
||||
for _, p := range full.Policies {
|
||||
slices.Sort(p.SourceGroupIds)
|
||||
slices.Sort(p.DestinationGroupIds)
|
||||
}
|
||||
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||
// a Policy has multiple rules) still get a deterministic order. After
|
||||
// sorting we remap indexes in ResourcePoliciesMap.
|
||||
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||
for i, p := range full.Policies {
|
||||
policyOldOrder[p] = uint32(i)
|
||||
}
|
||||
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||
return c
|
||||
}
|
||||
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||
})
|
||||
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||
for newIdx, p := range full.Policies {
|
||||
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||
}
|
||||
for _, idxs := range full.ResourcePoliciesMap {
|
||||
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||
}
|
||||
for _, list := range full.GroupIdToUserIds {
|
||||
slices.Sort(list.UserIds)
|
||||
}
|
||||
slices.Sort(full.AllowedUserIds)
|
||||
}
|
||||
|
||||
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||
out := make([]uint32, 0, len(idxs))
|
||||
for _, i := range idxs {
|
||||
if newIdx, ok := remap[i]; ok {
|
||||
out = append(out, newIdx)
|
||||
}
|
||||
}
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||
// the encoder is intentionally non-deterministic.
|
||||
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||
canonicalize(a.GetFull())
|
||||
canonicalize(b.GetFull())
|
||||
return goproto.Equal(a, b)
|
||||
}
|
||||
|
||||
func newTestComponents() *types.NetworkMapComponents {
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-a",
|
||||
Key: testWgKeyA,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peera",
|
||||
SSHKey: "ssh-a",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-b",
|
||||
Key: testWgKeyB,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||
DNSLabel: "peerb",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||
}
|
||||
peerC := &nbpeer.Peer{
|
||||
ID: "peer-c",
|
||||
Key: testWgKeyC,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
return &types.NetworkMapComponents{
|
||||
PeerID: "peer-a",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 7,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 2 * time.Hour,
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-a": peerA,
|
||||
"peer-b": peerB,
|
||||
"peer-c": peerC,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "pol-1",
|
||||
AccountSeqID: 10,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||
Ports: []string{"22", "80"},
|
||||
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full, "envelope must contain Full payload")
|
||||
|
||||
assert.EqualValues(t, 7, full.Serial)
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
|
||||
require.NotNil(t, full.Network)
|
||||
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||
|
||||
require.NotNil(t, full.AccountSettings)
|
||||
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
byLabel := map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||
// run produces different wire bytes, but the canonicalized form must
|
||||
// match.
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d must be semantically equivalent to first encode", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, got := range results {
|
||||
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"goroutine %d produced inequivalent envelope", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 2)
|
||||
|
||||
groupByID := map[uint32]*proto.GroupCompact{}
|
||||
for _, g := range full.Groups {
|
||||
groupByID[g.Id] = g
|
||||
}
|
||||
require.Contains(t, groupByID, uint32(1))
|
||||
require.Contains(t, groupByID, uint32(2))
|
||||
assert.Equal(t, "Src", groupByID[1].Name)
|
||||
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.EqualValues(t, 10, pc.Id)
|
||||
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||
assert.True(t, pc.Bidirectional)
|
||||
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||
require.Len(t, pc.PortRanges, 1)
|
||||
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.RouterPeerIndexes, 1)
|
||||
idx := full.RouterPeerIndexes[0]
|
||||
require.Less(t, int(idx), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||
"two distinct versions, order depends on map iteration")
|
||||
|
||||
idxByLabel := map[string]uint32{}
|
||||
for _, p := range full.Peers {
|
||||
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||
}
|
||||
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Policies[0].Enabled = false
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Groups["group-src"].AccountSeqID = 0
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||
// canonicalize identically.
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
|
||||
var byLabel = map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
v6Only := &nbpeer.Peer{
|
||||
ID: "peer-v6",
|
||||
Key: testWgKeyA,
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||
DNSLabel: "peerv6",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.Peers["peer-v6"] = v6Only
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerv6" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||
assert.Len(t, found.Ipv6, 16)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||
ID: "peer-noip",
|
||||
Key: testWgKeyA,
|
||||
DNSLabel: "peernoip",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peernoip" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found)
|
||||
assert.Empty(t, found.Ip)
|
||||
assert.Empty(t, found.Ipv6)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
}
|
||||
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Groups)
|
||||
assert.Empty(t, full.Policies)
|
||||
assert.Empty(t, full.RouterPeerIndexes)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
c.Peers["peer-a"].UserID = "user-1"
|
||||
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||
c.Peers["peer-a"].LastLogin = &now
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var pa *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peera" {
|
||||
pa = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pa)
|
||||
assert.True(t, pa.AddedWithSsoLogin)
|
||||
assert.True(t, pa.LoginExpirationEnabled)
|
||||
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||
|
||||
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||
var pb *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerb" {
|
||||
pb = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pb)
|
||||
assert.False(t, pb.AddedWithSsoLogin)
|
||||
assert.False(t, pb.LoginExpirationEnabled)
|
||||
assert.Zero(t, pb.LastLoginUnixNano)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{
|
||||
{
|
||||
ID: "route-peer",
|
||||
AccountSeqID: 100,
|
||||
NetID: "net-A",
|
||||
Description: "via peer-c",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Peer: "peer-c", // peer ID, not WG key
|
||||
Groups: []string{"group-src"},
|
||||
AccessControlGroups: []string{"group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-peergroup",
|
||||
AccountSeqID: 101,
|
||||
NetID: "net-B",
|
||||
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||
PeerGroups: []string{"group-src", "group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-no-seq",
|
||||
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 3)
|
||||
byNetID := map[string]*proto.RouteRaw{}
|
||||
for _, r := range full.Routes {
|
||||
byNetID[r.NetId] = r
|
||||
}
|
||||
|
||||
r1 := byNetID["net-A"]
|
||||
require.NotNil(t, r1)
|
||||
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||
assert.Empty(t, r1.PeerGroupIds)
|
||||
|
||||
r2 := byNetID["net-B"]
|
||||
require.NotNil(t, r2)
|
||||
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{{
|
||||
ID: "route-x",
|
||||
AccountSeqID: 100,
|
||||
Peer: "peer-not-in-components",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Enabled: true,
|
||||
}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 1)
|
||||
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||
"missing peer reference must not pretend to point at peer index 0")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||
// is the I1 case — without unionPolicies the encoder would silently
|
||||
// drop it from the wire.
|
||||
resourceOnlyPolicy := &types.Policy{
|
||||
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
}
|
||||
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||
}
|
||||
// Resource must appear in components.NetworkResources with a seq id —
|
||||
// encoder uses that to translate the xid map key to uint32.
|
||||
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||
|
||||
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||
policyIdxByID := map[uint32]uint32{}
|
||||
for i, p := range full.Policies {
|
||||
policyByID[p.Id] = p
|
||||
policyIdxByID[p.Id] = uint32(i)
|
||||
}
|
||||
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||
|
||||
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||
require.Len(t, idxs, 2)
|
||||
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||
"resource policies map must reference both wire policy indexes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||
}},
|
||||
Groups: []string{"group-src", "group-not-persisted"},
|
||||
Primary: true, Enabled: true,
|
||||
Domains: []string{"corp.example"},
|
||||
}}
|
||||
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.NameserverGroups, 1)
|
||||
nsg := full.NameserverGroups[0]
|
||||
assert.EqualValues(t, 50, nsg.Id)
|
||||
assert.Equal(t, "Main", nsg.Name)
|
||||
assert.True(t, nsg.Primary)
|
||||
require.Len(t, nsg.Nameservers, 1)
|
||||
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||
"check-1": {
|
||||
"peer-a": {},
|
||||
"peer-b": {},
|
||||
"peer-not-in-account": {},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {
|
||||
"peer-c": {
|
||||
ID: "router-1", AccountSeqID: 200,
|
||||
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
entries := full.RoutersMap[5].Entries
|
||||
require.Len(t, entries, 1)
|
||||
e := entries[0]
|
||||
assert.EqualValues(t, 200, e.Id)
|
||||
assert.True(t, e.PeerIndexSet)
|
||||
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||
assert.True(t, e.Masquerade)
|
||||
assert.EqualValues(t, 10, e.Metric)
|
||||
assert.True(t, e.Enabled)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||
// peer_index reference must still resolve.
|
||||
c := newTestComponents()
|
||||
delete(c.Peers, "peer-c")
|
||||
routerPeer := &nbpeer.Peer{
|
||||
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||
e := full.RoutersMap[5].Entries[0]
|
||||
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.DnsSettings)
|
||||
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.GroupIDToUserIDs = map[string][]string{
|
||||
"group-src": {"user-1", "user-2"},
|
||||
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||
"group-missing": {"user-4"}, // group not in components → drop
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||
}
|
||||
|
||||
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||
}
|
||||
|
||||
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||
nm := &types.NetworkMap{
|
||||
Peers: []*nbpeer.Peer{{
|
||||
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}},
|
||||
FirewallRules: []*types.FirewallRule{{
|
||||
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||
}},
|
||||
}
|
||||
|
||||
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||
|
||||
require.NotNil(t, patch)
|
||||
assert.Len(t, patch.Peers, 1)
|
||||
assert.Len(t, patch.FirewallRules, 1)
|
||||
}
|
||||
|
||||
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||
// pass-through in both encoder branches (normal path + nil-Components
|
||||
// graceful-degrade). Without this test a regression that drops `ProxyPatch:`
|
||||
// from one of the struct literals in components_encoder.go would slip past CI.
|
||||
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||
patch := &proto.ProxyPatch{
|
||||
ForwardingRules: []*proto.ForwardingRule{{
|
||||
Protocol: proto.RuleProtocol_TCP,
|
||||
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run("normal_path", func(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
|
||||
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||
// account_components.go:43 behaviour for missing/unvalidated peers.
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
// AccountSettings deliberately nil
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||
// field is intentionally left empty — capable peers ignore it and the
|
||||
// envelope alone is the authoritative wire shape.
|
||||
//
|
||||
// PeerConfig is computed once server-side using the receiving peer's own
|
||||
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||
// client even though the server-side PeerConfig reports false.
|
||||
func ToComponentSyncResponse(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
httpConfig *nbconfig.HttpServerConfig,
|
||||
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||
peer *nbpeer.Peer,
|
||||
turnCredentials *Token,
|
||||
relayCredentials *Token,
|
||||
components *types.NetworkMapComponents,
|
||||
proxyPatch *types.NetworkMap,
|
||||
dnsName string,
|
||||
checks []*posture.Checks,
|
||||
settings *types.Settings,
|
||||
extraSettings *types.ExtraSettings,
|
||||
peerGroups []string,
|
||||
dnsFwdPort int64,
|
||||
) *proto.SyncResponse {
|
||||
network := networkOrZero(components)
|
||||
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: peerConfig,
|
||||
DNSDomain: dnsName,
|
||||
DNSForwarderPort: dnsFwdPort,
|
||||
UserIDClaim: userIDClaim,
|
||||
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||
})
|
||||
|
||||
resp := &proto.SyncResponse{
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMapEnvelope: envelope,
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||
// dereferences network.Net which would panic on nil.
|
||||
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||
if c == nil || c.Network == nil {
|
||||
return &types.Network{}
|
||||
}
|
||||
return c.Network
|
||||
}
|
||||
|
||||
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||
// patch the components envelope ships alongside. Returns nil when there are
|
||||
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||
// sees no patch and skips the merge step entirely.
|
||||
//
|
||||
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||
// delivers fragments pre-expanded — there's no raw component shape to
|
||||
// derive them from. Components purity isn't violated: proxy data isn't
|
||||
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||
// client merges it on top of its locally-computed NetworkMap.
|
||||
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
patch := &proto.ProxyPatch{
|
||||
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||
}
|
||||
if len(nm.ForwardingRules) > 0 {
|
||||
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||
for _, r := range nm.ForwardingRules {
|
||||
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||
}
|
||||
}
|
||||
return patch
|
||||
}
|
||||
|
||||
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||
// without this helper the field would be incorrectly false for any peer
|
||||
// that's the destination of an SSH-enabling policy without having
|
||||
// peer.SSHEnabled set locally.
|
||||
//
|
||||
// Mirrors the two activation paths in Calculate() (`networkmap_components.go`
|
||||
// `getPeerConnectionResources`):
|
||||
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||
// destinations.
|
||||
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||
//
|
||||
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||
// runs Calculate() over the envelope.
|
||||
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||
if c == nil || peer == nil {
|
||||
return false
|
||||
}
|
||||
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||
// exist in c.Peers, otherwise no rule applies to it.
|
||||
if _, ok := c.Peers[peer.ID]; !ok {
|
||||
return false
|
||||
}
|
||||
for _, policy := range c.Policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||
// peer itself has SSH enabled locally.
|
||||
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||
if rule == nil || !rule.Enabled {
|
||||
return false
|
||||
}
|
||||
if !peerInDestinations(c, rule, peer.ID) {
|
||||
return false
|
||||
}
|
||||
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
return true
|
||||
}
|
||||
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||
// that exactly to avoid silent divergence).
|
||||
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
return rule.DestinationResource.ID == peerID
|
||||
}
|
||||
for _, groupID := range rule.Destinations {
|
||||
if c.IsPeerInGroup(peerID, groupID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||
// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for
|
||||
// the B1 fix that the prod-DB equivalence test alone wouldn't have caught
|
||||
// if no account had this combination.
|
||||
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||
const targetPeerID = "target"
|
||||
const targetGroupID = "g_dst"
|
||||
|
||||
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||
return &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||
Groups: map[string]*types.Group{targetGroupID: group},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
}},
|
||||
}, peer
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
peerSSH bool
|
||||
rule types.PolicyRule
|
||||
wantEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22022-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-all-protocol-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-port-range-covers-22",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "tcp-80-no-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "disabled-rule-skipped",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-not-in-destinations",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g_other"}, // target not in this group
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-typed-destination-resource-matches",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||
got := computeSSHEnabledForPeer(c, peer)
|
||||
assert.Equal(t, tc.wantEnabled, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||
// belt-and-suspenders presence guard mirroring Calculate's
|
||||
// getAllPeersFromGroups invariant.
|
||||
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||
c := &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||
Groups: map[string]*types.Group{
|
||||
"g": {ID: "g", Peers: []string{"missing"}},
|
||||
},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g"},
|
||||
}},
|
||||
}},
|
||||
}
|
||||
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||
"missing target peer must short-circuit to false, not consult policies")
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||
// components on graceful-degrade paths.
|
||||
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||
}
|
||||
@@ -7,18 +7,23 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
@@ -133,8 +138,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
@@ -147,19 +152,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
@@ -172,7 +177,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
@@ -183,6 +188,78 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
return response
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
@@ -200,6 +277,204 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
||||
// alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if shouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||
if config == nil || config.AuthAudience == "" {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -62,13 +61,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -100,7 +99,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -108,7 +107,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -136,9 +137,12 @@ type proxyConnection struct {
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// syncStream is set when the proxy connected via SyncMappings.
|
||||
// When non-nil, the sender goroutine uses this instead of stream.
|
||||
syncStream proto.ProxyService_SyncMappingsServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
@@ -206,145 +210,322 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// proxyConnectParams holds the validated parameters extracted from either
|
||||
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||
type proxyConnectParams struct {
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
ctx := stream.Context()
|
||||
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = req.GetCapabilities()
|
||||
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
stream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||
}
|
||||
|
||||
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||
// management waits for an ack from the proxy before sending the next batch.
|
||||
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
init, err := recvSyncInit(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = init.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
syncStream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
go s.drainRecv(stream, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||
}
|
||||
|
||||
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||
firstMsg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||
}
|
||||
init := firstMsg.GetInit()
|
||||
if init == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||
}
|
||||
return init, nil
|
||||
}
|
||||
|
||||
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||
// address availability for account-scoped tokens.
|
||||
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||
if proxyID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
if !isProxyAddressValid(address) {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
proxyAddress := req.GetAddress()
|
||||
if !isProxyAddressValid(proxyAddress) {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||
}
|
||||
}
|
||||
|
||||
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||
}
|
||||
|
||||
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||
// provides a partially initialised connSeed with stream-specific fields set;
|
||||
// the remaining fields are filled in here.
|
||||
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
|
||||
var accountID *string
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||
if token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
}
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
connSeed.proxyID = params.proxyID
|
||||
connSeed.sessionID = sessionID
|
||||
connSeed.address = params.address
|
||||
connSeed.accountID = accountID
|
||||
connSeed.tokenID = tokenID
|
||||
connSeed.capabilities = params.capabilities
|
||||
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||
connSeed.ctx = connCtx
|
||||
connSeed.cancel = cancel
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
if c := params.capabilities; c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
return connSeed, proxyRecord, nil
|
||||
}
|
||||
|
||||
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": newSessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||
// after a snapshot send failure.
|
||||
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||
// The proxy sends an ack for every incremental update; we don't need them
|
||||
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||
// canceled) is treated as a clean disconnect rather than an error.
|
||||
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"session_id": conn.sessionID,
|
||||
"address": conn.address,
|
||||
"cluster_addr": conn.address,
|
||||
"account_id": conn.accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
defer s.disconnectProxy(conn)
|
||||
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
if bidi && isStreamClosed(err) {
|
||||
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||
return nil
|
||||
}
|
||||
log.Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||
case <-conn.ctx.Done():
|
||||
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||
return conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// disconnectProxy removes the connection from cluster and store, unless it
|
||||
// has already been superseded by a newer connection.
|
||||
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||
// one batch, then waits for the proxy to ack before sending the next.
|
||||
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive ack: %w", err)
|
||||
}
|
||||
if msg.GetAck() == nil {
|
||||
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
@@ -381,6 +562,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -460,12 +644,26 @@ func isProxyAddressValid(addr string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy
|
||||
// isStreamClosed returns true for errors that indicate normal stream
|
||||
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||
func isStreamClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
return status.Code(err) == codes.Canceled
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy.
|
||||
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -475,6 +673,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||
if conn.syncStream != nil {
|
||||
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: resp.Mapping,
|
||||
InitialSyncComplete: resp.InitialSyncComplete,
|
||||
})
|
||||
}
|
||||
return conn.stream.Send(resp)
|
||||
}
|
||||
|
||||
// SendAccessLog processes access log from proxy
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
@@ -541,8 +750,8 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
|
||||
@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -932,31 +932,7 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
var plainResp *proto.SyncResponse
|
||||
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||
// computed and recompute the raw components instead. This wastes one
|
||||
// Calculate() call per initial-sync — the component-based wire
|
||||
// format is what the peer actually consumes. The streaming path
|
||||
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||
// because it dispatches by capability before computing.
|
||||
//
|
||||
// TODO(step-4-sync): refactor SyncPeer / SyncAndMarkPeer / their
|
||||
// mocks + manager interfaces to return PeerNetworkMapResult so the
|
||||
// initial-sync path stops doing duplicate work. ~13 files of churn,
|
||||
// deferred until the client-side decoder lands and there's a real
|
||||
// deployment of capability=3 peers worth optimizing for.
|
||||
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||
}
|
||||
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
} else {
|
||||
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
}
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
|
||||
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
411
management/internals/shared/grpc/sync_mappings_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records
|
||||
// sent messages and returns pre-loaded ack responses from Recv.
|
||||
type syncRecordingStream struct {
|
||||
grpc.ServerStream
|
||||
|
||||
mu sync.Mutex
|
||||
sent []*proto.SyncMappingsResponse
|
||||
recvMsgs []*proto.SyncMappingsRequest
|
||||
recvIdx int
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sent = append(s.sent, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.recvIdx >= len(s.recvMsgs) {
|
||||
return nil, fmt.Errorf("no more recv messages")
|
||||
}
|
||||
msg := s.recvMsgs[s.recvIdx]
|
||||
s.recvIdx++
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (s *syncRecordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *syncRecordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *syncRecordingStream) SendMsg(any) error { return nil }
|
||||
func (s *syncRecordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func ackMsg() *proto.SyncMappingsRequest {
|
||||
return &proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1 → 3 batches, 3 acks (one per batch, including final)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches")
|
||||
|
||||
assert.Len(t, stream.sent[0].Mapping, 3)
|
||||
assert.False(t, stream.sent[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[1].Mapping, 3)
|
||||
assert.False(t, stream.sent[1].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.sent[2].Mapping, 1)
|
||||
assert.True(t, stream.sent[2].InitialSyncComplete)
|
||||
|
||||
// All 3 acks consumed — including the final batch.
|
||||
assert.Equal(t, 3, stream.recvIdx)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, totalServices)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{ackMsg()},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.sent[0].Mapping)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4 // 2 batches → 1 ack needed, but we provide none
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// No acks available — Recv will return error.
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "receive ack")
|
||||
// First batch should have been sent before the error.
|
||||
require.Len(t, stream.sent, 1)
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
// Send an init message instead of an ack.
|
||||
stream := &syncRecordingStream{
|
||||
recvMsgs: []*proto.SyncMappingsRequest{
|
||||
{Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}},
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected ack")
|
||||
}
|
||||
|
||||
func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) {
|
||||
// Verify batches are sent strictly sequentially — batch N+1 is not sent
|
||||
// until the ack for batch N is received, including the final batch.
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6 // 3 batches, 3 acks
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []string
|
||||
|
||||
// Build a stream that logs send/recv events so we can verify ordering.
|
||||
ackCh := make(chan struct{}, 3)
|
||||
stream := &orderTrackingStream{
|
||||
mu: &mu,
|
||||
events: &events,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
// Feed acks asynchronously after a short delay to simulate real proxy.
|
||||
go func() {
|
||||
for range 3 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
ackCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Expected: send, recv-ack, send, recv-ack, send, recv-ack.
|
||||
require.Len(t, events, 6)
|
||||
assert.Equal(t, "send", events[0])
|
||||
assert.Equal(t, "recv", events[1])
|
||||
assert.Equal(t, "send", events[2])
|
||||
assert.Equal(t, "recv", events[3])
|
||||
assert.Equal(t, "send", events[4])
|
||||
assert.Equal(t, "recv", events[5])
|
||||
}
|
||||
|
||||
// orderTrackingStream logs "send" and "recv" events and blocks Recv until
|
||||
// an ack is signaled via ackCh.
|
||||
type orderTrackingStream struct {
|
||||
grpc.ServerStream
|
||||
mu *sync.Mutex
|
||||
events *[]string
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error {
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "send")
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
s.mu.Lock()
|
||||
*s.events = append(*s.events, "recv")
|
||||
s.mu.Unlock()
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *orderTrackingStream) Context() context.Context { return context.Background() }
|
||||
func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *orderTrackingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *orderTrackingStream) SendMsg(any) error { return nil }
|
||||
func (s *orderTrackingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 4
|
||||
const ttl = 100 * time.Millisecond
|
||||
const ackDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
// Build a stream that validates tokens immediately on Send, then
|
||||
// delays the ack to ensure the next batch's tokens are generated fresh.
|
||||
var validateErrs []error
|
||||
ackCh := make(chan struct{}, 2)
|
||||
stream := &tokenValidatingSyncStream{
|
||||
tokenStore: s.tokenStore,
|
||||
validateErrs: &validateErrs,
|
||||
ackCh: ackCh,
|
||||
}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Delay first ack so that if tokens were all generated upfront they'd expire.
|
||||
time.Sleep(ackDelay)
|
||||
ackCh <- struct{}{}
|
||||
// Final batch ack — immediate.
|
||||
ackCh <- struct{}{}
|
||||
}()
|
||||
|
||||
err := s.sendSnapshotSync(context.Background(), conn, stream)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid: per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
type tokenValidatingSyncStream struct {
|
||||
grpc.ServerStream
|
||||
tokenStore *OneTimeTokenStore
|
||||
validateErrs *[]error
|
||||
ackCh chan struct{}
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error {
|
||||
for _, mapping := range m.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil {
|
||||
*s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) {
|
||||
<-s.ackCh
|
||||
return ackMsg(), nil
|
||||
}
|
||||
|
||||
func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() }
|
||||
func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {}
|
||||
func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil }
|
||||
func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) {
|
||||
stream := &syncRecordingStream{}
|
||||
conn := &proxyConnection{
|
||||
syncStream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"},
|
||||
},
|
||||
InitialSyncComplete: true,
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.sent, 1)
|
||||
assert.Len(t, stream.sent[0].Mapping, 1)
|
||||
assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id)
|
||||
assert.True(t, stream.sent[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) {
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
resp := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Id: "svc-2", AccountId: "acct-2"},
|
||||
},
|
||||
}
|
||||
|
||||
err := conn.sendResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, stream.messages, 1)
|
||||
assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id)
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1621,14 +1621,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range newGroupsToCreate {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
|
||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -3036,16 +3036,6 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
|
||||
var newJWTGroup *types.Group
|
||||
for _, g := range groups {
|
||||
if g.Name == "group3" {
|
||||
newJWTGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
|
||||
@@ -96,12 +96,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -176,8 +170,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -229,12 +221,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -334,12 +320,6 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain stri
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
func (m *testServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1319,7 +1319,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Update non-existing router creates it",
|
||||
name: "Update non-existing router returns not found",
|
||||
networkId: "testNetworkId",
|
||||
routerId: "nonExistingRouterId",
|
||||
requestBody: &api.NetworkRouterRequest{
|
||||
@@ -1328,11 +1328,7 @@ func Test_NetworkRouters_Update(t *testing.T) {
|
||||
Metric: 100,
|
||||
Enabled: true,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "nonExistingRouterId", router.Id)
|
||||
},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Update router with both peer and peer_groups",
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||
// no-op once everything is consistent.
|
||||
//
|
||||
// Implemented as two table-wide SQL statements with window functions, one
|
||||
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||
// well under a second instead of the per-account-loop ~2 minutes.
|
||||
//
|
||||
// orderColumn is the column to use when assigning the deterministic ordering
|
||||
// (typically the primary-key string id).
|
||||
func BackfillAccountSeqIDs[T any](
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
entity types.AccountSeqEntity,
|
||||
orderColumn string,
|
||||
) error {
|
||||
var model T
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
table := quoteIdent(db, stmt.Schema.Table)
|
||||
orderCol := quoteIdent(db, orderColumn)
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var pending int64
|
||||
if err := tx.Raw(
|
||||
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||
).Scan(&pending).Error; err != nil {
|
||||
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||
}
|
||||
|
||||
if pending > 0 {
|
||||
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||
return fmt.Errorf("rank %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func quoteIdent(db *gorm.DB, name string) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql":
|
||||
return "`" + name + "`"
|
||||
case "postgres":
|
||||
return `"` + name + `"`
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres", "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
WITH max_seq AS (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
),
|
||||
ranked AS (
|
||||
SELECT p.id,
|
||||
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||
FROM %s p
|
||||
JOIN max_seq m ON p.account_id = m.account_id
|
||||
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||
)
|
||||
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||
FROM ranked
|
||||
WHERE %s.id = ranked.id
|
||||
`, table, orderCol, table, table, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
UPDATE %s p
|
||||
JOIN (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
) m ON p.account_id = m.account_id
|
||||
JOIN (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||
) r ON p.id = r.id
|
||||
SET p.account_seq_id = m.max_seq + r.rn
|
||||
`, table, table, orderCol, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql).Error
|
||||
}
|
||||
|
||||
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`, table)
|
||||
case "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql, string(entity)).Error
|
||||
}
|
||||
@@ -198,7 +198,11 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to insert account")
|
||||
|
||||
account.PeersG = []nbpeer.Peer{
|
||||
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
{
|
||||
AccountID: "1234",
|
||||
Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}},
|
||||
Status: &nbpeer.PeerStatus{LastSeen: time.Now()},
|
||||
},
|
||||
}
|
||||
|
||||
err = db.Save(account).Error
|
||||
|
||||
@@ -69,12 +69,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNSGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -126,8 +120,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -71,20 +71,9 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
network.ID = xid.New().String()
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network seq id: %w", err)
|
||||
}
|
||||
network.AccountSeqID = seq
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
||||
@@ -113,25 +102,14 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
network.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||
|
||||
return network, nil
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
|
||||
@@ -34,8 +34,11 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
|
||||
|
||||
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, networks, 1)
|
||||
require.Equal(t, "testNetworkId", networks[0].ID)
|
||||
ids := make([]string, 0, len(networks))
|
||||
for _, n := range networks {
|
||||
ids = append(ids, n.ID)
|
||||
}
|
||||
require.ElementsMatch(t, []string{"testNetworkId", "secondNetworkId"}, ids)
|
||||
}
|
||||
|
||||
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
|
||||
@@ -252,73 +255,3 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedNetwork)
|
||||
}
|
||||
|
||||
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
|
||||
// non-zero AccountSeqID on the persisted network (allocated through the
|
||||
// account_seq_counters table).
|
||||
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-allocation-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
|
||||
// AccountSeqID even when the caller passes a zero value (the shape REST
|
||||
// handlers produce because the field is `json:"-"`).
|
||||
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-preserve-original",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &types.Network{
|
||||
AccountID: accountID,
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = manager.UpdateNetwork(ctx, userID, update)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -125,12 +125,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network resource seq id: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -237,7 +231,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = oldResource.AccountSeqID
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,9 +32,6 @@ type NetworkResource struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
|
||||
Name string
|
||||
Description string
|
||||
Type NetworkResourceType
|
||||
@@ -96,18 +93,17 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
|
||||
|
||||
func (n *NetworkResource) Copy() *NetworkResource {
|
||||
return &NetworkResource{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,13 +102,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", err)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.CreateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
}
|
||||
@@ -168,27 +162,20 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
if network.ID != router.NetworkID {
|
||||
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
|
||||
if existing.AccountID != router.AccountID {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
if existing.NetworkID != router.NetworkID {
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
|
||||
if err == nil {
|
||||
router.AccountSeqID = oldRouter.AccountSeqID
|
||||
} else if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
|
||||
// PUT-as-upsert: caller may target a brand-new router id (used by
|
||||
// the dashboard's "save" flow). Allocate a fresh account_seq_id so
|
||||
// the upsert behaves the same as Create().
|
||||
seq, allocErr := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if allocErr != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", allocErr)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
} else {
|
||||
return fmt.Errorf("failed to get existing network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
err = transaction.UpdateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -211,6 +211,102 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
require.Equal(t, router.Metric, updatedRouter.Metric)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to update a router that belongs to otherAccountId
|
||||
// by passing the other account's router ID through the URL.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "otherRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
// The other account's router must be untouched.
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "otherAccountId", "otherRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "otherAccountId", stored.AccountID)
|
||||
require.Equal(t, "otherNetworkId", stored.NetworkID)
|
||||
require.Equal(t, "otherPeer", stored.Peer)
|
||||
require.Equal(t, 1, stored.Metric)
|
||||
}
|
||||
|
||||
func Test_CreateRouterRejectsCrossAccountID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// Admin of testAccountId tries to create a router in otherAccountId's network.
|
||||
// The permission check is on router.AccountID (their own), but the network
|
||||
// lookup must fail because (testAccountId, otherNetworkId) does not exist.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "otherNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
createdRouter, err := manager.CreateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, createdRouter)
|
||||
|
||||
// No router should have been created in either account's scope under otherNetworkId.
|
||||
routersInOther, err := s.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, "otherAccountId", "otherNetworkId")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routersInOther, 1)
|
||||
require.Equal(t, "otherRouterId", routersInOther[0].ID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterRejectsNetworkMismatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testAdminId"
|
||||
|
||||
// The router exists in testNetworkId, but the caller submits secondNetworkId
|
||||
// (a different network in the same account). The update must be refused.
|
||||
router, err := types.NewNetworkRouter("testAccountId", "secondNetworkId", "testPeerId", []string{}, false, 1, true)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
am := mock_server.MockAccountManager{}
|
||||
manager := NewManager(s, permissionsManager, &am)
|
||||
|
||||
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedRouter)
|
||||
|
||||
stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "testAccountId", "testRouterId")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "testNetworkId", stored.NetworkID)
|
||||
}
|
||||
|
||||
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
userID := "testUserId"
|
||||
|
||||
@@ -13,9 +13,6 @@ type NetworkRouter struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
Masquerade bool
|
||||
@@ -81,15 +78,14 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
|
||||
|
||||
func (n *NetworkRouter) Copy() *NetworkRouter {
|
||||
return &NetworkRouter{
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,24 +7,12 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
|
||||
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the network has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip networks that return false here.
|
||||
func (n *Network) HasSeqID() bool {
|
||||
return n != nil && n.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func NewNetwork(accountId, name, description string) *Network {
|
||||
return &Network{
|
||||
ID: xid.New().String(),
|
||||
@@ -53,14 +41,13 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a copy of a network.
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,9 +13,8 @@ import (
|
||||
|
||||
// Peer capability constants mirror the proto enum values.
|
||||
const (
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilityComponentNetworkMap int32 = 3
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
)
|
||||
|
||||
// Peer represents a machine connected to the network.
|
||||
@@ -87,7 +86,7 @@ type PeerStatus struct { //nolint:revive
|
||||
// active session". Integer nanoseconds are used so equality is
|
||||
// precision-safe across drivers, and so the predicates compose to a
|
||||
// single bigint comparison.
|
||||
SessionStartedAt int64
|
||||
SessionStartedAt int64 `gorm:"not null;default:0"`
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
// LoginExpired
|
||||
@@ -248,14 +247,6 @@ func (p *Peer) SupportsSourcePrefixes() bool {
|
||||
return p.HasCapability(PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
|
||||
// SupportsComponentNetworkMap reports whether the peer assembles its
|
||||
// NetworkMap from server-shipped components instead of consuming a fully
|
||||
// expanded NetworkMap. Determines whether the network_map controller skips
|
||||
// Calculate() server-side and emits the components envelope.
|
||||
func (p *Peer) SupportsComponentNetworkMap() bool {
|
||||
return p.HasCapability(PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func capabilitiesEqual(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
|
||||
@@ -2218,6 +2218,9 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
ID: "test-peer-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
DNSLabel: "test-peer-dns-label",
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -69,8 +69,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
policy.AccountSeqID = existingPolicy.AccountSeqID
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,12 +78,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,21 +47,10 @@ type Checks struct {
|
||||
// AccountID is a reference to the Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
|
||||
|
||||
// Checks is a set of objects that perform the actual checks
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the posture check has been persisted long enough
|
||||
// to have a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip checks that return false here.
|
||||
func (pc *Checks) HasSeqID() bool {
|
||||
return pc != nil && pc.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
@@ -132,12 +121,11 @@ func (*Checks) TableName() string {
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (pc *Checks) Copy() *Checks {
|
||||
checks := &Checks{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
AccountSeqID: pc.AccountSeqID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -51,24 +51,12 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = seq
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
|
||||
@@ -563,61 +563,3 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_AllocatesSeqIDOnCreate verifies that the create path
|
||||
// (no incoming ID) allocates a non-zero AccountSeqID via the
|
||||
// account_seq_counters table.
|
||||
func TestSavePostureChecks_AllocatesSeqIDOnCreate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-allocation-test",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "SavePostureChecks on create must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_PreservesSeqIDOnUpdate verifies the update path does
|
||||
// not reset AccountSeqID even when the caller passes a zero value (REST
|
||||
// handler shape, because the field is `json:"-"`).
|
||||
func TestSavePostureChecks_PreservesSeqIDOnUpdate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-preserve-original",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &posture.Checks{
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.27.0"},
|
||||
},
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, update, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := am.GetPostureChecks(context.Background(), account.Id, created.ID, adminUserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive SavePostureChecks update")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -178,12 +178,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRoute.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -237,7 +231,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
routeToSave.AccountSeqID = oldRoute.AccountSeqID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,506 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
var errRollback = errors.New("intentional rollback")
|
||||
|
||||
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accA = "acc-a"
|
||||
const accB = "acc-b"
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different account starts from 1")
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different entity starts from 1")
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), got, "counter persists across transactions")
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, policies, "test fixture must have policies")
|
||||
|
||||
seen := make(map[uint32]bool)
|
||||
for _, p := range policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
|
||||
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
|
||||
seen[p.AccountSeqID] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
const policyID = "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
|
||||
|
||||
updated := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Enabled: false,
|
||||
Rules: original.Rules,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
|
||||
|
||||
require.NoError(t, store.SavePolicy(ctx, updated))
|
||||
|
||||
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups)
|
||||
|
||||
original := groups[0]
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
updated := &types.Group{
|
||||
ID: original.ID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Issued: original.Issued,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID)
|
||||
|
||||
require.NoError(t, store.UpdateGroup(ctx, updated))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-seqid-test"
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
DNSSettings: types.DNSSettings{},
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
}
|
||||
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
|
||||
require.Len(t, account.Groups, 1, "default 'All' group must be present")
|
||||
require.Len(t, account.Policies, 1, "default policy must be present")
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
|
||||
}
|
||||
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
|
||||
|
||||
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
|
||||
|
||||
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
|
||||
return errRollback
|
||||
}), errRollback)
|
||||
}
|
||||
|
||||
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupSeqs := make(map[string]uint32)
|
||||
policySeqs := make(map[string]uint32)
|
||||
routeSeqs := make(map[route.ID]uint32)
|
||||
nsgSeqs := make(map[string]uint32)
|
||||
resourceSeqs := make(map[string]uint32)
|
||||
routerSeqs := make(map[string]uint32)
|
||||
networkSeqs := make(map[string]uint32)
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
|
||||
groupSeqs[g.ID] = g.AccountSeqID
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
|
||||
policySeqs[p.ID] = p.AccountSeqID
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
|
||||
routeSeqs[r.ID] = r.AccountSeqID
|
||||
}
|
||||
for _, n := range account.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
|
||||
nsgSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
|
||||
resourceSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
|
||||
routerSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture network must have seq>0 after backfill")
|
||||
networkSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, g := range after.Groups {
|
||||
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.Equal(t, networkSeqs[n.ID], n.AccountSeqID, "network %s seq must be preserved", n.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-all-entities"
|
||||
|
||||
addr, err := netip.ParseAddr("8.8.8.8")
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{Identifier: "net-test"},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
|
||||
},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
|
||||
},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
|
||||
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{ID: "n1", AccountID: accountID, Name: "n1"},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: "pc1", AccountID: accountID, Name: "pc1",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, after.Groups, 1)
|
||||
require.Len(t, after.Policies, 1)
|
||||
require.Len(t, after.Routes, 1)
|
||||
require.Len(t, after.NameServerGroups, 1)
|
||||
require.Len(t, after.NetworkResources, 1)
|
||||
require.Len(t, after.NetworkRouters, 1)
|
||||
require.Len(t, after.Networks, 1)
|
||||
require.Len(t, after.PostureChecks, 1)
|
||||
|
||||
for _, g := range after.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "network seq must be allocated")
|
||||
}
|
||||
for _, pc := range after.PostureChecks {
|
||||
require.NotZero(t, pc.AccountSeqID, "posture_check seq must be allocated")
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, after))
|
||||
final, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, r := range final.Routes {
|
||||
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
|
||||
}
|
||||
for _, n := range final.NameServerGroups {
|
||||
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
|
||||
}
|
||||
afterByID := map[string]uint32{}
|
||||
for _, n := range after.Networks {
|
||||
afterByID[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, n := range final.Networks {
|
||||
require.Equal(t, afterByID[n.ID], n.AccountSeqID, "network seq preserved on re-save")
|
||||
}
|
||||
afterPCByID := map[string]uint32{}
|
||||
for _, pc := range after.PostureChecks {
|
||||
afterPCByID[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
for _, pc := range final.PostureChecks {
|
||||
require.Equal(t, afterPCByID[pc.ID], pc.AccountSeqID, "posture_check seq preserved on re-save")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "concurrent-test"
|
||||
const entity = types.AccountSeqEntityPolicy
|
||||
const goroutines = 32
|
||||
|
||||
type result struct {
|
||||
seq uint32
|
||||
err error
|
||||
}
|
||||
results := make(chan result, goroutines)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
<-start
|
||||
var allocated uint32
|
||||
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
|
||||
allocated = seq
|
||||
return err
|
||||
})
|
||||
results <- result{seq: allocated, err: err}
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
|
||||
seen := make(map[uint32]int, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
r := <-results
|
||||
require.NoError(t, r.err, "concurrent allocate must not fail")
|
||||
require.NotZero(t, r.seq, "allocated seq must be non-zero")
|
||||
seen[r.seq]++
|
||||
}
|
||||
|
||||
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
|
||||
for i := uint32(1); i <= goroutines; i++ {
|
||||
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*types.Group{
|
||||
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
|
||||
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
for _, want := range groups {
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
|
||||
}
|
||||
|
||||
groups[0].Name = "g1-renamed"
|
||||
groups[0].AccountSeqID = 0
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
|
||||
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
|
||||
}
|
||||
|
||||
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
maxSeq := uint32(0)
|
||||
for _, p := range existing {
|
||||
if p.AccountSeqID > maxSeq {
|
||||
maxSeq = p.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
|
||||
|
||||
newPolicy := &types.Policy{
|
||||
ID: "bench-new-policy",
|
||||
AccountID: accountID,
|
||||
AccountSeqID: seq,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "bench-new-policy-rule",
|
||||
PolicyID: "bench-new-policy",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
}},
|
||||
}
|
||||
return tx.CreatePolicy(ctx, newPolicy)
|
||||
}))
|
||||
|
||||
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxSeq+1, created.AccountSeqID)
|
||||
}
|
||||
@@ -137,7 +137,6 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
&types.AccountSeqCounter{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -308,10 +307,6 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
|
||||
return fmt.Errorf("assign seq ids: %w", err)
|
||||
}
|
||||
|
||||
result = tx.
|
||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
@@ -663,22 +658,6 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
}
|
||||
|
||||
// CreateGroups creates the given list of groups to the database.
|
||||
// groupUpsertColumns is the explicit allowlist of columns that get updated when
|
||||
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
|
||||
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
|
||||
// handler-built struct) cannot reset the persisted seq id during an upsert.
|
||||
// Keep this in sync with the Group schema in management/server/types/group.go.
|
||||
func groupUpsertColumns() clause.Set {
|
||||
return clause.AssignmentColumns([]string{
|
||||
"account_id",
|
||||
"name",
|
||||
"issued",
|
||||
"integration_ref_id",
|
||||
"integration_ref_integration_type",
|
||||
"resources",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
@@ -688,9 +667,8 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -714,9 +692,8 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -2018,7 +1995,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2028,7 +2005,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
var resources []byte
|
||||
var refID sql.NullInt64
|
||||
var refType sql.NullString
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
if err == nil {
|
||||
if refID.Valid {
|
||||
g.IntegrationReference.ID = int(refID.Int64)
|
||||
@@ -2053,7 +2030,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2062,7 +2039,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
var p types.Policy
|
||||
var checks []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
p.Enabled = enabled.Bool
|
||||
@@ -2080,7 +2057,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
}
|
||||
|
||||
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2090,7 +2067,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
var network, domains, peerGroups, groups, accessGroups []byte
|
||||
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
if err == nil {
|
||||
if keepRoute.Valid {
|
||||
r.KeepRoute = keepRoute.Bool
|
||||
@@ -2132,7 +2109,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2141,7 +2118,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
var n nbdns.NameServerGroup
|
||||
var ns, groups, domains []byte
|
||||
var primary, enabled, searchDomainsEnabled sql.NullBool
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
if err == nil {
|
||||
if primary.Valid {
|
||||
n.Primary = primary.Bool
|
||||
@@ -2177,7 +2154,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2185,7 +2162,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
||||
var c posture.Checks
|
||||
var checksDef []byte
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.AccountSeqID, &c.Name, &c.Description, &checksDef)
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
||||
if err == nil && checksDef != nil {
|
||||
_ = json.Unmarshal(checksDef, &c.Checks)
|
||||
}
|
||||
@@ -2351,7 +2328,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description FROM networks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2368,7 +2345,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2378,7 +2355,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
var peerGroups []byte
|
||||
var masquerade, enabled sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
if err == nil {
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
@@ -2406,7 +2383,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2415,7 +2392,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
|
||||
var r resourceTypes.NetworkResource
|
||||
var prefix []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -3588,262 +3565,6 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must be called inside ExecuteInTransaction so the increment
|
||||
// is serialized with the component insert.
|
||||
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
|
||||
}
|
||||
|
||||
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
return allocateAccountSeqIDReturning(db, accountID, entity)
|
||||
case types.MysqlStoreEngine:
|
||||
return allocateAccountSeqIDMysql(db, accountID, entity)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
|
||||
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
|
||||
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
|
||||
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
|
||||
// branch and returns next_id+1.
|
||||
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, 2)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = account_seq_counters.next_id + 1
|
||||
RETURNING (next_id - 1)
|
||||
`
|
||||
var allocated uint32
|
||||
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
if allocated == 0 {
|
||||
return 0, fmt.Errorf("upsert account seq counter returned 0")
|
||||
}
|
||||
return allocated, nil
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
|
||||
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
|
||||
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
|
||||
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
|
||||
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
|
||||
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
|
||||
// so the follow-up SELECT sees the same value.
|
||||
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const upsertSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, LAST_INSERT_ID(2))
|
||||
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
|
||||
`
|
||||
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
var newNext uint64
|
||||
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
|
||||
return 0, fmt.Errorf("get last insert id: %w", err)
|
||||
}
|
||||
if newNext == 0 {
|
||||
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
|
||||
}
|
||||
return uint32(newNext - 1), nil
|
||||
}
|
||||
|
||||
// assignAccountSeqIDs allocates a per-account integer id for any component on
|
||||
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
|
||||
// the canonical "save the whole account" path produces the same persisted seq
|
||||
// ids that the manager-level Create paths produce. Update flows that go
|
||||
// through SaveAccount preserve existing non-zero values; for those, the
|
||||
// per-entity counter is bumped so subsequent AllocateAccountSeqID calls don't
|
||||
// hand out a colliding id.
|
||||
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
|
||||
maxByEntity := make(map[types.AccountSeqEntity]uint32, 8)
|
||||
bump := func(entity types.AccountSeqEntity, seq uint32) {
|
||||
if seq > maxByEntity[entity] {
|
||||
maxByEntity[entity] = seq
|
||||
}
|
||||
}
|
||||
|
||||
for i := range account.GroupsG {
|
||||
g := account.GroupsG[i]
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
if g.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityGroup, g.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
// Defensive: generateAccountSQLTypes currently aliases the same
|
||||
// *Group pointer into GroupsG and Groups[id] (so this is a no-op
|
||||
// today), but mirror the seq anyway so any future divergence in
|
||||
// how the two collections are populated doesn't silently leave
|
||||
// the canonical map view stale.
|
||||
if original, ok := account.Groups[g.ID]; ok && original != nil && original != g {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPolicy, p.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.AccountSeqID = seq
|
||||
}
|
||||
for i := range account.RoutesG {
|
||||
r := &account.RoutesG[i]
|
||||
if r.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityRoute, r.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.AccountSeqID = seq
|
||||
// Mirror the new seq onto the canonical map view so callers that
|
||||
// hold the same in-memory account post-Save read a consistent
|
||||
// AccountSeqID — without this, components/encoder code would see
|
||||
// 0 for routes saved this transaction until the account is reloaded.
|
||||
if original, ok := account.Routes[r.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for i := range account.NameServerGroupsG {
|
||||
ng := &account.NameServerGroupsG[i]
|
||||
if ng.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNameserverGroup, ng.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ng.AccountSeqID = seq
|
||||
if original, ok := account.NameServerGroups[ng.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkResource, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkRouter, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
if n == nil {
|
||||
continue
|
||||
}
|
||||
if n.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetwork, n.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.AccountSeqID = seq
|
||||
}
|
||||
for _, pc := range account.PostureChecks {
|
||||
if pc == nil {
|
||||
continue
|
||||
}
|
||||
if pc.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPostureCheck, pc.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pc.AccountSeqID = seq
|
||||
}
|
||||
for entity, maxSeq := range maxByEntity {
|
||||
if err := ensureAccountSeqCounter(tx, s.storeEngine, account.Id, entity, maxSeq+1); err != nil {
|
||||
return fmt.Errorf("seed counter for %s: %w", entity, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureAccountSeqCounter raises the per-account counter for entity to at
|
||||
// least target. Used when SaveAccount persists components that already carry
|
||||
// AccountSeqIDs (e.g. test bulk-load from sqlite to postgres, or migrations
|
||||
// running before component data lands) so that the next AllocateAccountSeqID
|
||||
// call returns a fresh id beyond what was just written.
|
||||
func ensureAccountSeqCounter(db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity, target uint32) error {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`
|
||||
// sqlite's UPSERT understands max() but the migration uses GREATEST
|
||||
// for postgres and max() for sqlite. We collapse to dialect-specific
|
||||
// statements only when needed.
|
||||
if engine == types.SqliteStoreEngine {
|
||||
const sqliteSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`
|
||||
return db.Exec(sqliteSQL, accountID, string(entity), target).Error
|
||||
}
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
case types.MysqlStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
default:
|
||||
return fmt.Errorf("unsupported store engine for account_seq counter: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
|
||||
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
|
||||
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
|
||||
@@ -4033,7 +3754,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
@@ -4121,7 +3842,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
@@ -4594,11 +4315,27 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
|
||||
return netRouter, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
result := s.db.Save(router)
|
||||
func (s *SqlStore) CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
if err := s.db.Create(router).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create network router in store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create network router in store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
||||
result := s.db.
|
||||
Select("*").
|
||||
Where(accountAndIDQueryCondition, router.AccountID, router.ID).
|
||||
Updates(router)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save network router to store")
|
||||
log.WithContext(ctx).Errorf("failed to update network router in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update network router in store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewNetworkRouterNotFoundError(router.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -6015,19 +5752,67 @@ func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, acc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
var clusters []proxy.Cluster
|
||||
// GetProxyClusters returns every cluster the account can see (shared
|
||||
// plus its own BYOP), regardless of whether any proxy in the cluster
|
||||
// is currently heartbeating. Online and ConnectedProxies are derived
|
||||
// from the 2-min active window so the dashboard can render offline
|
||||
// clusters distinctly; the 1-hour heartbeat reaper still removes rows
|
||||
// that go quiet for too long.
|
||||
//
|
||||
// AccountOwned is determined by whether any proxy row in the group
|
||||
// carries a non-NULL account_id; the caller maps that to Cluster.Type.
|
||||
// Capability flags are NOT filled here — the handler enriches them via
|
||||
// the per-cluster capability lookups.
|
||||
func (s *SqlStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
activeCutoff := time.Now().Add(-proxyActiveThreshold)
|
||||
|
||||
type clusterRow struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
Online bool
|
||||
AccountOwned bool
|
||||
}
|
||||
|
||||
var rows []clusterRow
|
||||
result := s.db.Model(&proxy.Proxy{}).
|
||||
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
|
||||
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
|
||||
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
|
||||
Select(
|
||||
"MIN(id) AS id, "+
|
||||
"cluster_address AS address, "+
|
||||
// COUNT(CASE WHEN ... THEN 1 END) counts only non-NULL — i.e. only
|
||||
// rows that satisfy the predicate — so it works portably across
|
||||
// sqlite/postgres/mysql without dialect-specific FILTER syntax.
|
||||
"COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS connected_proxies, "+
|
||||
// MAX(CASE …) > 0 expresses BOOL_OR in a way Postgres tolerates
|
||||
// (Postgres can't MAX a boolean column).
|
||||
"MAX(CASE WHEN status = ? AND last_seen > ? THEN 1 ELSE 0 END) > 0 AS online, "+
|
||||
"MAX(CASE WHEN account_id IS NOT NULL THEN 1 ELSE 0 END) > 0 AS account_owned",
|
||||
proxy.StatusConnected, activeCutoff,
|
||||
proxy.StatusConnected, activeCutoff,
|
||||
).
|
||||
Where("account_id IS NULL OR account_id = ?", accountID).
|
||||
Group("cluster_address").
|
||||
Scan(&clusters)
|
||||
Scan(&rows)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get active proxy clusters")
|
||||
log.WithContext(ctx).Errorf("failed to get proxy clusters: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get proxy clusters")
|
||||
}
|
||||
|
||||
clusters := make([]proxy.Cluster, 0, len(rows))
|
||||
for _, r := range rows {
|
||||
c := proxy.Cluster{
|
||||
ID: r.ID,
|
||||
Address: r.Address,
|
||||
Online: r.Online,
|
||||
ConnectedProxies: r.ConnectedProxies,
|
||||
}
|
||||
if r.AccountOwned {
|
||||
c.Type = proxy.ClusterTypeAccount
|
||||
} else {
|
||||
c.Type = proxy.ClusterTypeShared
|
||||
}
|
||||
clusters = append(clusters, c)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
|
||||
109
management/server/store/sql_store_proxy_clusters_test.go
Normal file
109
management/server/store/sql_store_proxy_clusters_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// TestSqlStore_GetProxyClusters_DerivesOnlineAndType guards the
|
||||
// account-visible cluster list against silent regressions in two
|
||||
// dimensions:
|
||||
//
|
||||
// 1. Online derivation: a cluster with one stale and one fresh proxy
|
||||
// is online and counts only the fresh proxy; a cluster whose
|
||||
// proxies all heartbeated outside the 2-min window appears offline
|
||||
// with connected_proxies = 0 (rather than disappearing, which is
|
||||
// what the old query did).
|
||||
// 2. Type derivation: a cluster scoped to the calling account is
|
||||
// reported as `account`; a cluster with account_id IS NULL is
|
||||
// reported as `shared`. Clusters scoped to other accounts must not
|
||||
// leak into the result.
|
||||
//
|
||||
// Capability flags are intentionally not asserted here — they're filled
|
||||
// by the manager (handler) layer from the per-cluster capability
|
||||
// lookups, not by the store query.
|
||||
func TestSqlStore_GetProxyClusters_DerivesOnlineAndType(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
accountID := "acct-clusters"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
|
||||
|
||||
otherAccountID := "acct-other"
|
||||
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, otherAccountID, "user-2", "")))
|
||||
|
||||
acctID := accountID
|
||||
otherID := otherAccountID
|
||||
|
||||
fresh := time.Now().Add(-30 * time.Second)
|
||||
stale := time.Now().Add(-30 * time.Minute)
|
||||
|
||||
mustSave := func(id, cluster string, accID *string, status string, lastSeen time.Time) {
|
||||
require.NoError(t, store.SaveProxy(ctx, &rpproxy.Proxy{
|
||||
ID: id,
|
||||
SessionID: id + "-sess",
|
||||
ClusterAddress: cluster,
|
||||
IPAddress: "10.0.0.1",
|
||||
AccountID: accID,
|
||||
LastSeen: lastSeen,
|
||||
Status: status,
|
||||
}))
|
||||
}
|
||||
|
||||
// shared-mixed: one fresh + one stale proxy → online, connected=1
|
||||
mustSave("p-shared-fresh", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, fresh)
|
||||
mustSave("p-shared-stale", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, stale)
|
||||
|
||||
// shared-offline: only stale proxies → offline, connected=0,
|
||||
// but row must still appear (this is the new semantic — old
|
||||
// query would have dropped it entirely).
|
||||
mustSave("p-shared-off", "shared-offline.netbird.io", nil, rpproxy.StatusConnected, stale)
|
||||
|
||||
// account-online: BYOP cluster owned by acctID, fresh
|
||||
mustSave("p-acct-fresh", "byop.acct.example", &acctID, rpproxy.StatusConnected, fresh)
|
||||
|
||||
// other-account: must not surface for acctID
|
||||
mustSave("p-other", "byop.other.example", &otherID, rpproxy.StatusConnected, fresh)
|
||||
|
||||
clusters, err := store.GetProxyClusters(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
byAddr := map[string]rpproxy.Cluster{}
|
||||
for _, c := range clusters {
|
||||
byAddr[c.Address] = c
|
||||
}
|
||||
|
||||
assert.NotContains(t, byAddr, "byop.other.example",
|
||||
"another account's BYOP cluster must not leak into this account's listing")
|
||||
|
||||
require.Contains(t, byAddr, "shared-mixed.netbird.io")
|
||||
mixed := byAddr["shared-mixed.netbird.io"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeShared, mixed.Type, "shared cluster (account_id IS NULL) must be reported as Type=shared")
|
||||
assert.True(t, mixed.Online, "cluster with a fresh proxy must be online")
|
||||
assert.Equal(t, 1, mixed.ConnectedProxies, "connected_proxies must count only fresh proxies; the stale one should not bump the count")
|
||||
|
||||
require.Contains(t, byAddr, "shared-offline.netbird.io",
|
||||
"offline clusters must still appear so the dashboard can render them — the old GetActiveProxyClusters would have dropped this row, which is the regression this test guards against")
|
||||
offline := byAddr["shared-offline.netbird.io"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeShared, offline.Type)
|
||||
assert.False(t, offline.Online, "no fresh heartbeat → offline")
|
||||
assert.Equal(t, 0, offline.ConnectedProxies, "no fresh proxies → connected_proxies=0")
|
||||
|
||||
require.Contains(t, byAddr, "byop.acct.example")
|
||||
acct := byAddr["byop.acct.example"]
|
||||
assert.Equal(t, rpproxy.ClusterTypeAccount, acct.Type, "BYOP cluster owned by the account must be reported as Type=account")
|
||||
assert.True(t, acct.Online)
|
||||
assert.Equal(t, 1, acct.ConnectedProxies)
|
||||
})
|
||||
}
|
||||
@@ -2399,7 +2399,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
func TestSqlStore_CreateNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
@@ -2410,7 +2410,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveNetworkRouter(context.Background(), netRouter)
|
||||
err = store.CreateNetworkRouter(context.Background(), netRouter)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, netRouter.ID)
|
||||
@@ -2418,6 +2418,39 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
require.Equal(t, netRouter, savedNetRouter)
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
networkID := "ct286bi7qv930dsrrug0"
|
||||
routerID := "ctc20ji7qv9ck2sebc80"
|
||||
|
||||
netRouter := &routerTypes.NetworkRouter{
|
||||
ID: routerID,
|
||||
AccountID: accountID,
|
||||
NetworkID: networkID,
|
||||
Peer: "",
|
||||
PeerGroups: []string{"net-router-grp"},
|
||||
Masquerade: true,
|
||||
Metric: 42,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err = store.UpdateNetworkRouter(context.Background(), netRouter)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, routerID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, netRouter, savedNetRouter)
|
||||
|
||||
// Updating a router under a different account must not match any row.
|
||||
netRouter.AccountID = "non-existent-account"
|
||||
err = store.UpdateNetworkRouter(context.Background(), netRouter)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -220,11 +220,6 @@ type Store interface {
|
||||
GetStoreEngine() types.Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must run inside a transaction so the increment is serialized
|
||||
// with the component insert.
|
||||
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
||||
@@ -233,7 +228,8 @@ type Store interface {
|
||||
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
||||
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
||||
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
||||
SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
||||
DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
|
||||
|
||||
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
||||
@@ -312,7 +308,7 @@ type Store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -476,6 +472,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[nbpeer.Peer](ctx, db, "peer_status_session_started_at", int64(0))
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
||||
},
|
||||
@@ -527,30 +526,6 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[networkTypes.Network](ctx, db, types.AccountSeqEntityNetwork, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[posture.Checks](ctx, db, types.AccountSeqEntityPostureCheck, "id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -310,6 +310,20 @@ func (mr *MockStoreMockRecorder) CreateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroups", reflect.TypeOf((*MockStore)(nil).CreateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// CreateNetworkRouter mocks base method.
|
||||
func (m *MockStore) CreateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateNetworkRouter indicates an expected call of CreateNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) CreateNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNetworkRouter", reflect.TypeOf((*MockStore)(nil).CreateNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// CreatePeerJob mocks base method.
|
||||
func (m *MockStore) CreatePeerJob(ctx context.Context, job *types2.Job) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -380,6 +394,20 @@ func (mr *MockStoreMockRecorder) DeleteAccount(ctx, account interface{}) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccount", reflect.TypeOf((*MockStore)(nil).DeleteAccount), ctx, account)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteCustomDomain mocks base method.
|
||||
func (m *MockStore) DeleteCustomDomain(ctx context.Context, accountID, domainID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -577,20 +605,6 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteRoute mocks base method.
|
||||
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -731,6 +745,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// EphemeralServiceExists mocks base method.
|
||||
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -746,21 +774,6 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID mocks base method.
|
||||
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
|
||||
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1347,21 +1360,6 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2063,6 +2061,21 @@ func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetProxyClusters mocks base method.
|
||||
func (m *MockStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyClusters", ctx, accountID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyClusters indicates an expected call of GetProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2613,6 +2626,36 @@ func (mr *MockStoreMockRecorder) MarkPATUsed(ctx, patID interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPATUsed", reflect.TypeOf((*MockStore)(nil).MarkPATUsed), ctx, patID)
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession mocks base method.
|
||||
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession mocks base method.
|
||||
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPendingJobsAsFailed mocks base method.
|
||||
func (m *MockStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2823,20 +2866,6 @@ func (mr *MockStoreMockRecorder) SaveNetworkResource(ctx, resource interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkResource", reflect.TypeOf((*MockStore)(nil).SaveNetworkResource), ctx, resource)
|
||||
}
|
||||
|
||||
// SaveNetworkRouter mocks base method.
|
||||
func (m *MockStore) SaveNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveNetworkRouter indicates an expected call of SaveNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) SaveNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkRouter", reflect.TypeOf((*MockStore)(nil).SaveNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// SavePAT mocks base method.
|
||||
func (m *MockStore) SavePAT(ctx context.Context, pat *types2.PersonalAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2893,36 +2922,6 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession mocks base method.
|
||||
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession mocks base method.
|
||||
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// SavePolicy mocks base method.
|
||||
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2965,20 +2964,6 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3188,6 +3173,20 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// UpdateNetworkRouter mocks base method.
|
||||
func (m *MockStore) UpdateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateNetworkRouter", ctx, router)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateNetworkRouter indicates an expected call of UpdateNetworkRouter.
|
||||
func (mr *MockStoreMockRecorder) UpdateNetworkRouter(ctx, router interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNetworkRouter", reflect.TypeOf((*MockStore)(nil).UpdateNetworkRouter), ctx, router)
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
4
management/server/testdata/networks.sql
vendored
4
management/server/testdata/networks.sql
vendored
@@ -9,9 +9,13 @@ INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM
|
||||
|
||||
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description');
|
||||
INSERT INTO networks VALUES('secondNetworkId','testAccountId','second-name','second-description');
|
||||
|
||||
CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999);
|
||||
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.65.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO networks VALUES('otherNetworkId','otherAccountId','other-net','other-description');
|
||||
INSERT INTO network_routers VALUES('otherRouterId','otherNetworkId','otherAccountId','otherPeer',NULL,0,1);
|
||||
|
||||
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32');
|
||||
|
||||
@@ -1006,15 +1006,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
// PolicyRuleImpliesLegacySSH reports whether the rule (without an explicit
|
||||
// NetbirdSSH protocol) implicitly authorises SSH because it permits TCP/22 or
|
||||
// TCP/22022 — either by ALL-protocol coverage or by an explicit port/port-range
|
||||
// containing one of those. Exposed for ToComponentSyncResponse so the
|
||||
// envelope-format response mirrors the legacy SshConfig.SshEnabled bit.
|
||||
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return policyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
@@ -16,49 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// GetPeerNetworkMapResult dispatches to either the legacy-NetworkMap path or
|
||||
// the components path based on the peer's capability and the kill switch.
|
||||
// Capable peers (PeerCapabilityComponentNetworkMap) get the raw components
|
||||
// shape — the server skips Calculate() entirely for them, saving CPU
|
||||
// proportional to the number of capable peers in the account. Legacy peers
|
||||
// (or any peer when componentsDisabled is true) get the fully-expanded
|
||||
// NetworkMap as before.
|
||||
func (a *Account) GetPeerNetworkMapResult(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
componentsDisabled bool,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) PeerNetworkMapResult {
|
||||
peer := a.Peers[peerID]
|
||||
if !componentsDisabled && peer != nil && peer.SupportsComponentNetworkMap() {
|
||||
components := a.GetPeerNetworkMapComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs,
|
||||
)
|
||||
// Mirror legacy graceful-degrade: GetPeerNetworkMapFromComponents
|
||||
// returns &NetworkMap{Network: a.Network.Copy()} when components is
|
||||
// nil. Match that floor so the receiving client always sees the
|
||||
// account Network identifier, not a fully-empty envelope.
|
||||
if components == nil {
|
||||
components = &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
return PeerNetworkMapResult{Components: components}
|
||||
}
|
||||
return PeerNetworkMapResult{
|
||||
NetworkMap: a.GetPeerNetworkMapFromComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, metrics, groupIDToUserIDs,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapFromComponents(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
@@ -125,27 +82,15 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
}
|
||||
|
||||
components := &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
NetworkXIDToSeq: make(map[string]uint32, len(a.Networks)),
|
||||
PostureCheckXIDToSeq: make(map[string]uint32, len(a.PostureChecks)),
|
||||
}
|
||||
for _, n := range a.Networks {
|
||||
if n != nil && n.HasSeqID() {
|
||||
components.NetworkXIDToSeq[n.ID] = n.AccountSeqID
|
||||
}
|
||||
}
|
||||
for _, pc := range a.PostureChecks {
|
||||
if pc != nil && pc.HasSeqID() {
|
||||
components.PostureCheckXIDToSeq[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
}
|
||||
|
||||
components.AccountSettings = &AccountSettingsInfo{
|
||||
@@ -308,44 +253,18 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
|
||||
relevantPeerIDs[peerID] = a.GetPeer(peerID)
|
||||
|
||||
peerGroupSet := make(map[string]struct{}, 8)
|
||||
for groupID, group := range a.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
peerGroupSet[groupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
routeAccessControlGroups := make(map[string]struct{})
|
||||
for _, r := range a.Routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
relevant := r.Peer == peerID
|
||||
if !relevant {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant && r.Enabled {
|
||||
for _, groupID := range r.Groups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, groupID := range r.PeerGroups {
|
||||
for _, groupID := range r.Groups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
for _, groupID := range r.Groups {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
if r.Enabled {
|
||||
@@ -566,13 +485,6 @@ func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChe
|
||||
return dest
|
||||
}
|
||||
|
||||
// filterGroupPeers trims each group's Peers slice to only those peers that
|
||||
// also appear in `peers`. Groups whose filtered list is empty are NOT
|
||||
// deleted from the map — they're kept so the components wire encoder can
|
||||
// still resolve seq references from routes/policies/access-control groups
|
||||
// that name them. Calculate() tolerates groups with empty Peers (the inner
|
||||
// loops simply iterate zero times), so retaining them is behaviourally a
|
||||
// no-op for the legacy path that consumes the same NetworkMapComponents.
|
||||
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
|
||||
for groupID, groupInfo := range *groups {
|
||||
filteredPeers := make([]string, 0, len(groupInfo.Peers))
|
||||
@@ -582,7 +494,9 @@ func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
if len(filteredPeers) == 0 {
|
||||
delete(*groups, groupID)
|
||||
} else if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
ng := groupInfo.Copy()
|
||||
ng.Peers = filteredPeers
|
||||
(*groups)[groupID] = ng
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package types
|
||||
|
||||
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
|
||||
type AccountSeqEntity string
|
||||
|
||||
const (
|
||||
AccountSeqEntityPolicy AccountSeqEntity = "policy"
|
||||
AccountSeqEntityGroup AccountSeqEntity = "group"
|
||||
AccountSeqEntityRoute AccountSeqEntity = "route"
|
||||
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
|
||||
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
|
||||
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
|
||||
AccountSeqEntityNetwork AccountSeqEntity = "network"
|
||||
AccountSeqEntityPostureCheck AccountSeqEntity = "posture_check"
|
||||
)
|
||||
|
||||
// AccountSeqCounter tracks the next per-account integer id for a given component
|
||||
// kind. Reads/writes go through the store inside the same transaction as the
|
||||
// component insert so two concurrent inserts cannot collide on the same id.
|
||||
type AccountSeqCounter struct {
|
||||
AccountID string `gorm:"primaryKey;size:255"`
|
||||
Entity string `gorm:"primaryKey;size:32"`
|
||||
NextID uint32 `gorm:"not null;default:1"`
|
||||
}
|
||||
|
||||
// TableName overrides the GORM-derived table name.
|
||||
func (AccountSeqCounter) TableName() string {
|
||||
return "account_seq_counters"
|
||||
}
|
||||
@@ -19,10 +19,6 @@ type Group struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
@@ -45,14 +41,6 @@ type GroupPeer struct {
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the group has been persisted long enough to have a
|
||||
// per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip groups that return false here — otherwise multiple unpersisted
|
||||
// groups would collide on id 0.
|
||||
func (g *Group) HasSeqID() bool {
|
||||
return g != nil && g.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
@@ -86,7 +74,6 @@ func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
AccountSeqID: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
|
||||
@@ -42,17 +42,6 @@ type NetworkMapComponents struct {
|
||||
PostureFailedPeers map[string]map[string]struct{}
|
||||
|
||||
RouterPeers map[string]*nbpeer.Peer
|
||||
|
||||
// NetworkXIDToSeq maps Network.ID (xid) → AccountSeqID. Populated by the
|
||||
// account-side component builder; consumed by the envelope encoder to
|
||||
// translate RoutersMap keys and NetworkResource.NetworkID references
|
||||
// to compact uint32 ids. Legacy Calculate() doesn't consult it.
|
||||
NetworkXIDToSeq map[string]uint32
|
||||
|
||||
// PostureCheckXIDToSeq maps posture.Checks.ID (xid) → AccountSeqID.
|
||||
// Same role as NetworkXIDToSeq, used for PostureFailedPeers keys and
|
||||
// policy SourcePostureChecks references.
|
||||
PostureCheckXIDToSeq map[string]uint32
|
||||
}
|
||||
|
||||
type AccountSettingsInfo struct {
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// wireBenchScales mirrors the scales used by networkmap_benchmark_test.go but
|
||||
// trimmed: encoding+marshal are linear, so we don't need the 30k peer extreme
|
||||
// to see the trend.
|
||||
var wireBenchScales = []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
}
|
||||
|
||||
// populateAccountSeqIDs assigns deterministic AccountSeqIDs to every group and
|
||||
// policy in the account so that the component encoder can reference them. The
|
||||
// scalableTestAccount fixture builds entities by struct literal and skips this
|
||||
// step, but production paths populate the IDs via the store layer.
|
||||
func populateAccountSeqIDs(account *types.Account) {
|
||||
var nextGroupSeq uint32 = 1
|
||||
for _, g := range account.Groups {
|
||||
g.AccountSeqID = nextGroupSeq
|
||||
nextGroupSeq++
|
||||
}
|
||||
var nextPolicySeq uint32 = 1
|
||||
for _, p := range account.Policies {
|
||||
p.AccountSeqID = nextPolicySeq
|
||||
nextPolicySeq++
|
||||
}
|
||||
}
|
||||
|
||||
// assignValidWgKeys overwrites every peer's Key with a valid base64-encoded
|
||||
// 32-byte string. The default scalableTestAccount uses unparsable strings
|
||||
// like "key-peer-0", which makes the components encoder emit a nil WgPubKey
|
||||
// and the legacy encoder ship 10-char placeholders — both shrink the wire
|
||||
// size in unrealistic ways. Production peers always have valid 44-char base64
|
||||
// keys, so any benchmark/breakdown that wants honest numbers must call this.
|
||||
func assignValidWgKeys(account *types.Account) {
|
||||
for _, p := range account.Peers {
|
||||
var raw [32]byte
|
||||
_, _ = rand.Read(raw[:])
|
||||
p.Key = base64.StdEncoding.EncodeToString(raw[:])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireEncode reports per-call ns and the marshaled wire
|
||||
// size for both encoding paths. Run with:
|
||||
//
|
||||
// go test -run=^$ -bench=BenchmarkNetworkMapWireEncode -benchmem ./management/server/types/
|
||||
func BenchmarkNetworkMapWireEncode(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
// Pre-encode once so the size metric is identical for every run inside
|
||||
// the same scale; the b.Loop call only re-runs encode + Marshal.
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
envelopeInput := mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
}
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
envelopeBytes, err := goproto.Marshal(envelope)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("legacy/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(legacyBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
resp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
if _, err := goproto.Marshal(resp.NetworkMap); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("components/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(envelopeBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
if _, err := goproto.Marshal(env); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireSize is a fast snapshot of the wire size by scale
|
||||
// without a tight encode loop. Run with -bench to see one ns/op + bytes per
|
||||
// scale (treat the timing as informational; the sample is one Marshal per
|
||||
// scale, not the full b.N loop).
|
||||
func BenchmarkNetworkMapWireSize(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
envBytes, err := goproto.Marshal(env)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("size/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportMetric(float64(len(legacyBytes)), "legacy_bytes")
|
||||
b.ReportMetric(float64(len(envBytes)), "components_bytes")
|
||||
ratio := float64(len(envBytes)) / float64(len(legacyBytes))
|
||||
b.ReportMetric(ratio, "components/legacy")
|
||||
for range b.N {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkMapWireBreakdown is a one-shot diagnostic: it computes the wire
|
||||
// size attributable to each top-level field of both the legacy NetworkMap and
|
||||
// the components NetworkMapEnvelope at the 5000-peer scale, so the migration
|
||||
// docs can attribute the size reduction to each optimization. Runs only on
|
||||
// demand via -run TestNetworkMapWireBreakdown.
|
||||
func TestNetworkMapWireBreakdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("size diagnostic, skipped with -short")
|
||||
}
|
||||
if os.Getenv("NB_RUN_WIRE_BREAKDOWN") != "1" {
|
||||
t.Skip("set NB_RUN_WIRE_BREAKDOWN=1 to run wire breakdown diagnostic")
|
||||
}
|
||||
|
||||
const peerCount, groupCount = 5000, 100
|
||||
account, validatedPeers := scalableTestAccount(peerCount, groupCount)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyTotal := mustMarshalSize(t, legacyResp.NetworkMap)
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
componentsTotal := mustMarshalSize(t, envelope)
|
||||
|
||||
t.Logf("\n=== LEGACY NetworkMap (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes\n", legacyTotal)
|
||||
|
||||
legacyBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMap
|
||||
}{
|
||||
{"RemotePeers", &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers}},
|
||||
{"OfflinePeers", &proto.NetworkMap{OfflinePeers: legacyResp.NetworkMap.OfflinePeers}},
|
||||
{"FirewallRules", &proto.NetworkMap{FirewallRules: legacyResp.NetworkMap.FirewallRules}},
|
||||
{"Routes", &proto.NetworkMap{Routes: legacyResp.NetworkMap.Routes}},
|
||||
{"RoutesFirewallRules", &proto.NetworkMap{RoutesFirewallRules: legacyResp.NetworkMap.RoutesFirewallRules}},
|
||||
{"DNSConfig", &proto.NetworkMap{DNSConfig: legacyResp.NetworkMap.DNSConfig}},
|
||||
{"PeerConfig", &proto.NetworkMap{PeerConfig: legacyResp.NetworkMap.PeerConfig}},
|
||||
{"SshAuth", &proto.NetworkMap{SshAuth: legacyResp.NetworkMap.SshAuth}},
|
||||
}
|
||||
for _, e := range legacyBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, legacyTotal))
|
||||
}
|
||||
|
||||
full := envelope.GetFull()
|
||||
if full == nil {
|
||||
t.Fatalf("expected full network map envelope payload, got nil")
|
||||
}
|
||||
t.Logf("\n=== COMPONENTS NetworkMapEnvelope (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes (%.1f%% of legacy)\n", componentsTotal, pct(componentsTotal, legacyTotal))
|
||||
|
||||
componentsBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMapComponentsFull
|
||||
}{
|
||||
{"Peers", &proto.NetworkMapComponentsFull{Peers: full.Peers}},
|
||||
{"Policies", &proto.NetworkMapComponentsFull{Policies: full.Policies}},
|
||||
{"Groups", &proto.NetworkMapComponentsFull{Groups: full.Groups}},
|
||||
{"Routes (raw)", &proto.NetworkMapComponentsFull{Routes: full.Routes}},
|
||||
{"NameServerGroups", &proto.NetworkMapComponentsFull{NameserverGroups: full.NameserverGroups}},
|
||||
{"AllDNSRecords", &proto.NetworkMapComponentsFull{AllDnsRecords: full.AllDnsRecords}},
|
||||
{"AccountZones", &proto.NetworkMapComponentsFull{AccountZones: full.AccountZones}},
|
||||
{"NetworkResources", &proto.NetworkMapComponentsFull{NetworkResources: full.NetworkResources}},
|
||||
{"RoutersMap", &proto.NetworkMapComponentsFull{RoutersMap: full.RoutersMap}},
|
||||
{"ResourcePoliciesMap", &proto.NetworkMapComponentsFull{ResourcePoliciesMap: full.ResourcePoliciesMap}},
|
||||
{"GroupIDToUserIDs", &proto.NetworkMapComponentsFull{GroupIdToUserIds: full.GroupIdToUserIds}},
|
||||
{"AllowedUserIDs", &proto.NetworkMapComponentsFull{AllowedUserIds: full.AllowedUserIds}},
|
||||
{"PostureFailedPeers", &proto.NetworkMapComponentsFull{PostureFailedPeers: full.PostureFailedPeers}},
|
||||
{"DNSSettings", &proto.NetworkMapComponentsFull{DnsSettings: full.DnsSettings}},
|
||||
{"PeerConfig", &proto.NetworkMapComponentsFull{PeerConfig: full.PeerConfig}},
|
||||
{"AgentVersions", &proto.NetworkMapComponentsFull{AgentVersions: full.AgentVersions}},
|
||||
}
|
||||
for _, e := range componentsBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, componentsTotal))
|
||||
}
|
||||
|
||||
t.Logf("\n=== Per-PeerCompact average ===")
|
||||
if len(full.Peers) > 0 {
|
||||
t.Logf(" PeerCompact avg: %d bytes/peer", mustMarshalSize(t, &proto.NetworkMapComponentsFull{Peers: full.Peers})/len(full.Peers))
|
||||
}
|
||||
if len(legacyResp.NetworkMap.RemotePeers) > 0 {
|
||||
t.Logf(" RemotePeer avg: %d bytes/peer",
|
||||
mustMarshalSize(t, &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers})/len(legacyResp.NetworkMap.RemotePeers))
|
||||
}
|
||||
|
||||
t.Logf("\n=== FirewallRule expansion footprint ===")
|
||||
t.Logf(" legacy FirewallRules count: %d", len(legacyResp.NetworkMap.FirewallRules))
|
||||
t.Logf(" components Policies count: %d", len(full.Policies))
|
||||
t.Logf(" components Groups count: %d", len(full.Groups))
|
||||
|
||||
totalGroupPeerIdxs := 0
|
||||
for _, g := range full.Groups {
|
||||
totalGroupPeerIdxs += len(g.PeerIndexes)
|
||||
}
|
||||
t.Logf(" components peer-index refs across all groups: %d", totalGroupPeerIdxs)
|
||||
}
|
||||
|
||||
func mustMarshalSize(t *testing.T, m goproto.Message) int {
|
||||
b, err := goproto.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
return len(b)
|
||||
}
|
||||
|
||||
func pct(part, total int) float64 {
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
return 100 * float64(part) / float64(total)
|
||||
}
|
||||
|
||||
// Stops fmt being unused if the breakdown loop above is later commented out.
|
||||
var _ = fmt.Sprintf
|
||||
@@ -1,25 +0,0 @@
|
||||
package types
|
||||
|
||||
// PeerNetworkMapResult is what the network_map controller produces for a
|
||||
// single peer. Exactly one of NetworkMap or Components is populated depending
|
||||
// on the peer's capability:
|
||||
//
|
||||
// - Components-capable peers (PeerCapabilityComponentNetworkMap) get
|
||||
// Components: the raw types.NetworkMapComponents the client decodes and
|
||||
// runs Calculate() on locally. NetworkMap stays nil — the server skips
|
||||
// the expansion entirely.
|
||||
// - Legacy peers (or any peer when the kill switch is set) get NetworkMap:
|
||||
// the fully-expanded view the legacy gRPC path consumes.
|
||||
//
|
||||
// The gRPC layer (ToSyncResponseForPeer) dispatches by which field is
|
||||
// non-nil; callers must not rely on both being set.
|
||||
type PeerNetworkMapResult struct {
|
||||
NetworkMap *NetworkMap
|
||||
Components *NetworkMapComponents
|
||||
}
|
||||
|
||||
// IsComponents reports whether the result carries the components shape.
|
||||
// Use this in preference to direct nil checks on the fields.
|
||||
func (r PeerNetworkMapResult) IsComponents() bool {
|
||||
return r.Components != nil
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// helper: marks the given peer as components-capable.
|
||||
func markCapable(p *nbpeer.Peer) {
|
||||
p.Meta.Capabilities = append(p.Meta.Capabilities, nbpeer.PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_CapablePeerGetsComponents(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false, // componentsDisabled
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
require.True(t, result.IsComponents(), "capable peer must get the components shape")
|
||||
assert.Nil(t, result.NetworkMap)
|
||||
require.NotNil(t, result.Components)
|
||||
assert.Equal(t, "peer-0", result.Components.PeerID)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_LegacyPeerGetsNetworkMap(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
// peer-0 left without the component capability
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false,
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents())
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap, "legacy peer must get a NetworkMap")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_KillSwitchOverridesCapability(t *testing.T) {
|
||||
// Capable peer + componentsDisabled=true → falls back to legacy.
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
true, // componentsDisabled = true (kill switch)
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents(), "kill switch must force legacy NetworkMap path")
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap)
|
||||
}
|
||||
|
||||
func TestPeerNetworkMapResult_IsComponents(t *testing.T) {
|
||||
assert.True(t, types.PeerNetworkMapResult{Components: &types.NetworkMapComponents{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{NetworkMap: &types.NetworkMap{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{}.IsComponents())
|
||||
}
|
||||
@@ -59,10 +59,6 @@ type Policy struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
@@ -79,19 +75,11 @@ type Policy struct {
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the policy has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip policies that return false here.
|
||||
func (p *Policy) HasSeqID() bool {
|
||||
return p != nil && p.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// Copy returns a copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
AccountSeqID: p.AccountSeqID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -59,7 +60,7 @@ func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
|
||||
}
|
||||
@@ -79,7 +80,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "initial sync should not be marked done without flag")
|
||||
}
|
||||
@@ -97,7 +98,7 @@ func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync done flag should be set even without health checker")
|
||||
}
|
||||
|
||||
@@ -25,6 +25,11 @@ type Metrics struct {
|
||||
backendDuration metric.Int64Histogram
|
||||
certificateIssueDuration metric.Int64Histogram
|
||||
|
||||
// Management sync metrics.
|
||||
snapshotSyncDuration metric.Int64Histogram
|
||||
snapshotBatchDuration metric.Int64Histogram
|
||||
addPeerDuration metric.Int64Histogram
|
||||
|
||||
// L4 service-level metrics.
|
||||
l4Services metric.Int64UpDownCounter
|
||||
|
||||
@@ -54,6 +59,9 @@ func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
if err := m.initHTTPMetrics(meter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.initSyncMetrics(meter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.initL4Metrics(meter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -126,6 +134,59 @@ func (m *Metrics) initHTTPMetrics(meter metric.Meter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Metrics) initSyncMetrics(meter metric.Meter) error {
|
||||
var err error
|
||||
|
||||
m.snapshotSyncDuration, err = meter.Int64Histogram(
|
||||
"proxy.sync.snapshot.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration from management connect until the initial snapshot sync is complete"),
|
||||
metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.snapshotBatchDuration, err = meter.Int64Histogram(
|
||||
"proxy.sync.batch.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration to process a single mapping batch during initial snapshot sync"),
|
||||
metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.addPeerDuration, err = meter.Int64Histogram(
|
||||
"proxy.peer.add.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration to add a peer for an account (keygen + gRPC CreateProxyPeer + embed.New)"),
|
||||
metric.WithExplicitBucketBoundaries(10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// RecordSnapshotSyncDuration records the total time from connect to sync-complete.
|
||||
func (m *Metrics) RecordSnapshotSyncDuration(d time.Duration) {
|
||||
m.snapshotSyncDuration.Record(m.ctx, d.Milliseconds())
|
||||
}
|
||||
|
||||
// RecordSnapshotBatchDuration records the time to process one mapping batch during initial sync.
|
||||
func (m *Metrics) RecordSnapshotBatchDuration(d time.Duration) {
|
||||
m.snapshotBatchDuration.Record(m.ctx, d.Milliseconds())
|
||||
}
|
||||
|
||||
// RecordAddPeerDuration records the time to create a new peer for an account.
|
||||
func (m *Metrics) RecordAddPeerDuration(d time.Duration, err error) {
|
||||
result := "success"
|
||||
if err != nil {
|
||||
result = "error"
|
||||
}
|
||||
m.addPeerDuration.Record(m.ctx, d.Milliseconds(), metric.WithAttributes(
|
||||
attribute.String("result", result),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *Metrics) initL4Metrics(meter metric.Meter) error {
|
||||
var err error
|
||||
|
||||
|
||||
@@ -76,6 +76,11 @@ type clientEntry struct {
|
||||
services map[ServiceKey]serviceInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// ready is closed once the client has been fully initialized.
|
||||
// Callers that find a pending entry wait on this channel before
|
||||
// accessing the client. A nil initErr means success.
|
||||
ready chan struct{}
|
||||
initErr error
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
@@ -137,6 +142,11 @@ type NetBird struct {
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
|
||||
// OnAddPeer, when set, is called after AddPeer completes for a new account
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -157,6 +167,9 @@ type skipTLSVerifyContextKey struct{}
|
||||
// AddPeer registers a service for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple services can share the same client.
|
||||
//
|
||||
// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux
|
||||
// so that concurrent AddPeer calls for different accounts execute in parallel.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
@@ -164,10 +177,23 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
ready := entry.ready
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// If the entry is still being initialized by another goroutine, wait.
|
||||
if ready != nil {
|
||||
select {
|
||||
case <-ready:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if entry.initErr != nil {
|
||||
return fmt.Errorf("peer initialization failed: %w", entry.initErr)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -184,15 +210,43 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
// Insert a placeholder so other goroutines calling AddPeer for the same
|
||||
// account will wait on the ready channel instead of starting a second
|
||||
// client creation.
|
||||
entry = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: si},
|
||||
ready: make(chan struct{}),
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
createStart := time.Now()
|
||||
created, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
if n.OnAddPeer != nil {
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
entry.initErr = err
|
||||
close(entry.ready)
|
||||
|
||||
n.clientsMux.Lock()
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clients[accountID] = entry
|
||||
// Transfer any services that were registered by concurrent AddPeer calls
|
||||
// while we were creating the client.
|
||||
n.clientsMux.Lock()
|
||||
for k, v := range entry.services {
|
||||
created.services[k] = v
|
||||
}
|
||||
created.ready = nil
|
||||
n.clients[accountID] = created
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
close(entry.ready)
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -200,13 +254,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
go n.runClientStartup(ctx, accountID, created.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
// and creates an embedded NetBird client.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
|
||||
@@ -366,7 +366,7 @@ func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, doma
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
func (m *storeBackedServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
300
proxy/process_mappings_bench_test.go
Normal file
300
proxy/process_mappings_bench_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/crowdsec"
|
||||
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
udprelay "github.com/netbirdio/netbird/proxy/internal/udp"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
)
|
||||
|
||||
// latencyMockClient simulates realistic gRPC latency for management calls.
|
||||
type latencyMockClient struct {
|
||||
proto.ProxyServiceClient
|
||||
createPeerDelay time.Duration
|
||||
statusUpdateDelay time.Duration
|
||||
}
|
||||
|
||||
func (m *latencyMockClient) SendStatusUpdate(ctx context.Context, _ *proto.SendStatusUpdateRequest, _ ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
if m.statusUpdateDelay > 0 {
|
||||
select {
|
||||
case <-time.After(m.statusUpdateDelay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *latencyMockClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
if m.createPeerDelay > 0 {
|
||||
select {
|
||||
case <-time.After(m.createPeerDelay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type discardWriter struct{}
|
||||
|
||||
func (discardWriter) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func benchServerWithLatency(b *testing.B, createPeerDelay, statusDelay time.Duration) *Server {
|
||||
b.Helper()
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.FatalLevel)
|
||||
logger.SetOutput(&discardWriter{})
|
||||
|
||||
meter, err := proxymetrics.New(context.Background(), noop.Meter{})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
mgmtClient := &latencyMockClient{
|
||||
createPeerDelay: createPeerDelay,
|
||||
statusUpdateDelay: statusDelay,
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird("bench-proxy", "bench.test",
|
||||
roundtrip.ClientConfig{MgmtAddr: "http://bench.test:9999"},
|
||||
logger, nil, mgmtClient)
|
||||
|
||||
mainRouter := nbtcp.NewRouter(logger, func(accountID types.AccountID) (types.DialContextFunc, error) {
|
||||
return (&net.Dialer{}).DialContext, nil
|
||||
}, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443})
|
||||
|
||||
return &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: mgmtClient,
|
||||
netbird: nb,
|
||||
proxy: proxy.NewReverseProxy(nil, "auto", nil, logger),
|
||||
auth: auth.NewMiddleware(logger, nil, nil),
|
||||
mainRouter: mainRouter,
|
||||
mainPort: 443,
|
||||
meter: meter,
|
||||
hijackTracker: conntrack.HijackTracker{},
|
||||
crowdsecRegistry: crowdsec.NewRegistry("", "", log.NewEntry(logger)),
|
||||
crowdsecServices: make(map[types.ServiceID]bool),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
portRouters: make(map[uint16]*portRouter),
|
||||
svcPorts: make(map[types.ServiceID][]uint16),
|
||||
udpRelays: make(map[types.ServiceID]*udprelay.Relay),
|
||||
}
|
||||
}
|
||||
|
||||
// generateHTTPMappings creates N HTTP-mode mappings with the given update type.
|
||||
// All belong to a single account to share the embedded client.
|
||||
func generateHTTPMappings(n int, updateType proto.ProxyMappingUpdateType) []*proto.ProxyMapping {
|
||||
mappings := make([]*proto.ProxyMapping, n)
|
||||
for i := range n {
|
||||
mappings[i] = &proto.ProxyMapping{
|
||||
Type: updateType,
|
||||
Id: fmt.Sprintf("svc-%d", i),
|
||||
AccountId: "account-1",
|
||||
Domain: fmt.Sprintf("svc-%d.bench.example.com", i),
|
||||
Mode: "http",
|
||||
Path: []*proto.PathMapping{
|
||||
{
|
||||
Path: "/",
|
||||
Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256),
|
||||
},
|
||||
},
|
||||
Auth: &proto.Authentication{},
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
// generateMultiAccountHTTPMappings creates N HTTP-mode CREATED mappings spread
|
||||
// across the given number of accounts. This stresses the AddPeer new-account
|
||||
// path which calls CreateProxyPeer + embed.New per unique account.
|
||||
func generateMultiAccountHTTPMappings(n, accounts int) []*proto.ProxyMapping {
|
||||
mappings := make([]*proto.ProxyMapping, n)
|
||||
for i := range n {
|
||||
mappings[i] = &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: fmt.Sprintf("svc-%d", i),
|
||||
AccountId: fmt.Sprintf("account-%d", i%accounts),
|
||||
Domain: fmt.Sprintf("svc-%d.bench.example.com", i),
|
||||
Mode: "http",
|
||||
Path: []*proto.PathMapping{
|
||||
{
|
||||
Path: "/",
|
||||
Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256),
|
||||
},
|
||||
},
|
||||
Auth: &proto.Authentication{},
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
// generateMixedMappings creates mappings with a realistic distribution:
|
||||
// 70% HTTP create, 15% modify existing, 10% TLS on main port, 5% remove.
|
||||
// All use a single account to avoid embed.New dialing.
|
||||
func generateMixedMappings(n int) []*proto.ProxyMapping {
|
||||
mappings := make([]*proto.ProxyMapping, n)
|
||||
for i := range n {
|
||||
var m *proto.ProxyMapping
|
||||
switch {
|
||||
case i%20 < 14: // 70% HTTP create
|
||||
m = &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: fmt.Sprintf("svc-http-%d", i),
|
||||
AccountId: "account-1",
|
||||
Domain: fmt.Sprintf("svc-%d.bench.example.com", i),
|
||||
Mode: "http",
|
||||
Path: []*proto.PathMapping{
|
||||
{Path: "/", Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256)},
|
||||
{Path: "/api", Target: fmt.Sprintf("http://10.0.%d.%d:8081", (i/256)%256, i%256)},
|
||||
},
|
||||
Auth: &proto.Authentication{},
|
||||
}
|
||||
case i%20 < 17: // 15% modify
|
||||
m = &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED,
|
||||
Id: fmt.Sprintf("svc-http-%d", i%100),
|
||||
AccountId: "account-1",
|
||||
Domain: fmt.Sprintf("svc-%d.bench.example.com", i%100),
|
||||
Mode: "http",
|
||||
Path: []*proto.PathMapping{
|
||||
{Path: "/", Target: fmt.Sprintf("http://10.1.%d.%d:8080", (i/256)%256, i%256)},
|
||||
},
|
||||
Auth: &proto.Authentication{},
|
||||
}
|
||||
case i%20 < 19: // 10% TLS passthrough on main port
|
||||
m = &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: fmt.Sprintf("svc-tls-%d", i),
|
||||
AccountId: "account-1",
|
||||
Domain: fmt.Sprintf("tls-%d.bench.example.com", i),
|
||||
Mode: "tls",
|
||||
ListenPort: 443,
|
||||
Path: []*proto.PathMapping{
|
||||
{Path: "/", Target: fmt.Sprintf("10.2.%d.%d:443", (i/256)%256, i%256)},
|
||||
},
|
||||
}
|
||||
default: // 5% remove
|
||||
m = &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: fmt.Sprintf("svc-http-%d", i%50),
|
||||
AccountId: "account-1",
|
||||
Domain: fmt.Sprintf("svc-%d.bench.example.com", i%50),
|
||||
Mode: "http",
|
||||
}
|
||||
}
|
||||
mappings[i] = m
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
const (
|
||||
createPeerLatency = 100 * time.Millisecond
|
||||
statusUpdateLatency = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
// BenchmarkProcessMappings_HTTPCreate_SingleAccount benchmarks the initial sync
|
||||
// scenario: N HTTP mappings all on a single account. Only the first mapping
|
||||
// triggers CreateProxyPeer (100ms gRPC). The rest just register with the
|
||||
// existing client. This is the "best case" production path.
|
||||
func BenchmarkProcessMappings_HTTPCreate_SingleAccount(b *testing.B) {
|
||||
for _, n := range []int{100, 1000, 5000} {
|
||||
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
|
||||
mappings := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED)
|
||||
for range b.N {
|
||||
s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency)
|
||||
s.processMappings(b.Context(), mappings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkProcessMappings_HTTPCreate_MultiAccount benchmarks the worst-case
|
||||
// initial sync: every mapping belongs to a different account, so each one
|
||||
// triggers a full CreateProxyPeer gRPC round-trip (100ms) + embed.New.
|
||||
// With 500 accounts this serializes to ~50s of blocking I/O.
|
||||
func BenchmarkProcessMappings_HTTPCreate_MultiAccount(b *testing.B) {
|
||||
for _, tc := range []struct {
|
||||
mappings int
|
||||
accounts int
|
||||
}{
|
||||
{100, 10},
|
||||
{100, 50},
|
||||
{1000, 50},
|
||||
{1000, 200},
|
||||
{3000, 500},
|
||||
} {
|
||||
b.Run(fmt.Sprintf("mappings=%d/accounts=%d", tc.mappings, tc.accounts), func(b *testing.B) {
|
||||
mappings := generateMultiAccountHTTPMappings(tc.mappings, tc.accounts)
|
||||
for range b.N {
|
||||
s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency)
|
||||
s.processMappings(b.Context(), mappings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkProcessMappings_Mixed benchmarks a realistic mixed workload
|
||||
// of creates, modifies, TLS, and removes with production-like latency.
|
||||
// TLS mappings call SendStatusUpdate (50ms each), serialized.
|
||||
func BenchmarkProcessMappings_Mixed(b *testing.B) {
|
||||
for _, n := range []int{100, 1000, 5000} {
|
||||
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
|
||||
mappings := generateMixedMappings(n)
|
||||
for range b.N {
|
||||
s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency)
|
||||
creates := generateHTTPMappings(100, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED)
|
||||
s.processMappings(b.Context(), creates)
|
||||
s.processMappings(b.Context(), mappings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkProcessMappings_ModifyOnly benchmarks bulk modification of
|
||||
// already-registered mappings (no new peers needed, no gRPC).
|
||||
func BenchmarkProcessMappings_ModifyOnly(b *testing.B) {
|
||||
for _, n := range []int{100, 1000, 5000} {
|
||||
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
|
||||
creates := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED)
|
||||
modifies := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
for range b.N {
|
||||
s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency)
|
||||
s.processMappings(b.Context(), creates)
|
||||
s.processMappings(b.Context(), modifies)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkProcessMappings_NoLatency measures pure CPU/allocation overhead
|
||||
// with zero I/O latency for profiling purposes.
|
||||
func BenchmarkProcessMappings_NoLatency(b *testing.B) {
|
||||
for _, n := range []int{1000, 5000} {
|
||||
b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
|
||||
mappings := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED)
|
||||
for range b.N {
|
||||
s := benchServerWithLatency(b, 0, 0)
|
||||
s.processMappings(b.Context(), mappings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
278
proxy/server.go
278
proxy/server.go
@@ -32,9 +32,11 @@ import (
|
||||
"go.opentelemetry.io/otel/sdk/metric"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
@@ -282,6 +284,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
@@ -938,6 +941,9 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
// syncSupported tracks whether management supports SyncMappings.
|
||||
// Starts true; set to false on first Unimplemented error.
|
||||
syncSupported := true
|
||||
initialSyncDone := false
|
||||
|
||||
operation := func() error {
|
||||
@@ -949,36 +955,25 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established")
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
|
||||
streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone)
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
}
|
||||
@@ -995,56 +990,187 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error {
|
||||
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
return &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: s.proxyCapabilities(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established (GetMappingUpdate)")
|
||||
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sync stream: %w", err)
|
||||
}
|
||||
|
||||
// Send init message.
|
||||
if err := stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Init{
|
||||
Init: &proto.SyncMappingsInit{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: s.proxyCapabilities(),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established (SyncMappings)")
|
||||
|
||||
return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func isSyncUnimplemented(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
return ok && st.Code() == codes.Unimplemented
|
||||
}
|
||||
|
||||
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case err != nil:
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
if err := stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send ack: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Shutting down.
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := mappingClient.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
// Mapping connection gracefully terminated by server.
|
||||
return nil
|
||||
case err != nil:
|
||||
// Something has gone horribly wrong, return and hope the parent retries the connection.
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotTracker accumulates service IDs during the initial snapshot and
|
||||
// finalises sync state when the complete flag arrives.
|
||||
type snapshotTracker struct {
|
||||
done *bool
|
||||
connectTime time.Time
|
||||
snapshotIDs map[types.ServiceID]struct{}
|
||||
}
|
||||
|
||||
func (s *Server) newSnapshotTracker(done *bool, connectTime time.Time) *snapshotTracker {
|
||||
var ids map[types.ServiceID]struct{}
|
||||
if !*done {
|
||||
ids = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
return &snapshotTracker{done: done, connectTime: connectTime, snapshotIDs: ids}
|
||||
}
|
||||
|
||||
func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings []*proto.ProxyMapping, syncComplete bool, batchStart time.Time) {
|
||||
if *t.done {
|
||||
return
|
||||
}
|
||||
|
||||
if s.meter != nil {
|
||||
s.meter.RecordSnapshotBatchDuration(time.Since(batchStart))
|
||||
}
|
||||
|
||||
for _, m := range mappings {
|
||||
t.snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
|
||||
if !syncComplete {
|
||||
return
|
||||
}
|
||||
|
||||
s.reconcileSnapshot(ctx, t.snapshotIDs)
|
||||
t.snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*t.done = true
|
||||
if s.meter != nil {
|
||||
s.meter.RecordSnapshotSyncDuration(time.Since(t.connectTime))
|
||||
}
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
@@ -1067,6 +1193,8 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
s.ensurePeers(ctx, mappings)
|
||||
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
@@ -1100,6 +1228,60 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
|
||||
}
|
||||
}
|
||||
|
||||
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
|
||||
// CREATED mappings. Peers for different accounts are created concurrently,
|
||||
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
|
||||
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
// Collect one representative mapping per account that needs a new peer.
|
||||
type peerReq struct {
|
||||
accountID types.AccountID
|
||||
svcKey roundtrip.ServiceKey
|
||||
authToken string
|
||||
svcID types.ServiceID
|
||||
}
|
||||
seen := make(map[types.AccountID]struct{})
|
||||
var reqs []peerReq
|
||||
for _, m := range mappings {
|
||||
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
|
||||
continue
|
||||
}
|
||||
accountID := types.AccountID(m.GetAccountId())
|
||||
if _, ok := seen[accountID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[accountID] = struct{}{}
|
||||
if s.netbird.HasClient(accountID) {
|
||||
continue
|
||||
}
|
||||
reqs = append(reqs, peerReq{
|
||||
accountID: accountID,
|
||||
svcKey: s.serviceKeyForMapping(m),
|
||||
authToken: m.GetAuthToken(),
|
||||
svcID: types.ServiceID(m.GetId()),
|
||||
})
|
||||
}
|
||||
|
||||
if len(reqs) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(reqs))
|
||||
for _, r := range reqs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": r.accountID,
|
||||
"service_id": r.svcID,
|
||||
"error": err,
|
||||
}).Warn("failed to pre-create peer for account")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// addMapping registers a service mapping and starts the appropriate relay or routes.
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -139,7 +140,7 @@ func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||
}
|
||||
|
||||
syncDone := true // sync already completed in a previous stream
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2,
|
||||
@@ -185,7 +186,7 @@ func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||
|
||||
@@ -218,7 +219,7 @@ func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone, time.Time{})
|
||||
assert.Error(t, err)
|
||||
assert.False(t, syncDone)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user